From e10cdaa3e1fc096c887cc93450f8face28fbf0d2 Mon Sep 17 00:00:00 2001 From: Kuba Mazurkiewicz Date: Wed, 3 Jul 2024 08:45:12 +0200 Subject: [PATCH] modified client for ise and added tests --- .gitignore | 5 +- nac_collector/cisco_client.py | 46 ++++- nac_collector/cisco_client_ise.py | 213 +++++++++++---------- nac_collector/github_repo_wrapper.py | 81 ++++---- nac_collector/main.py | 27 ++- pyproject.toml | 3 +- tests/ise/integration/__init__.py | 0 tests/ise/integration/test_integration.py | 112 +++++++++++ tests/ise/unit/__init__.py | 0 tests/ise/unit/test_authentication.py | 40 ++++ tests/ise/unit/test_ers_api_pagination.py | 132 +++++++++++++ tests/ise/unit/test_fetch_data.py | 54 ++++++ tests/ise/unit/test_id_value_extraction.py | 29 +++ tests/ise/unit/test_initialization.py | 24 +++ 14 files changed, 616 insertions(+), 150 deletions(-) create mode 100644 tests/ise/integration/__init__.py create mode 100644 tests/ise/integration/test_integration.py create mode 100644 tests/ise/unit/__init__.py create mode 100644 tests/ise/unit/test_authentication.py create mode 100644 tests/ise/unit/test_ers_api_pagination.py create mode 100644 tests/ise/unit/test_fetch_data.py create mode 100644 tests/ise/unit/test_id_value_extraction.py create mode 100644 tests/ise/unit/test_initialization.py diff --git a/.gitignore b/.gitignore index 10f131b..4565926 100644 --- a/.gitignore +++ b/.gitignore @@ -1,9 +1,10 @@ -nac_collector/__pycache__ *.json +*.yaml tmp/ .envrc .DS_Store .pylintrc # pyenv -.python-version \ No newline at end of file +.python-version +__pycache__/ \ No newline at end of file diff --git a/nac_collector/cisco_client.py b/nac_collector/cisco_client.py index e5a64c1..3fdfda0 100644 --- a/nac_collector/cisco_client.py +++ b/nac_collector/cisco_client.py @@ -93,6 +93,7 @@ def get_request(self, url): response = self.session.get( url, verify=self.ssl_verify, timeout=self.timeout ) + except requests.exceptions.Timeout: self.logger.error( "GET %s timed out after %s seconds.", url, self.timeout @@ -178,7 +179,7 @@ def log_response(self, endpoint, response): """ Logs the response from a GET request. - Args: + Parameters: endpoint (str): The endpoint the request was sent to. response (Response): The response from the request. """ @@ -195,17 +196,48 @@ def log_response(self, endpoint, response): response.status_code, ) - def write_to_json(self, final_dict, solution): + def fetch_data(self, endpoint): + """ + Fetch data from a specified endpoint. + + Parameters: + endpoint (str): Endpoint URL. + + Returns: + data (dict): The JSON content of the response or None if an error occurred. + """ + # Make the request to the given endpoint + response = self.get_request(self.base_url + endpoint) + if response: + try: + # Get the JSON content of the response + data = response.json() + self.logger.info( + "GET %s succeeded with status code %s", + endpoint, + response.status_code, + ) + return data + except ValueError: + self.logger.error( + "Failed to decode JSON from response for endpoint: %s", endpoint + ) + return None + else: + self.logger.error("No valid response received for endpoint: %s", endpoint) + return None + + def write_to_json(self, final_dict, output): """ Writes the final dictionary to a JSON file. - Args: + Parameters: final_dict (dict): The final dictionary to write to the file. - solution (str): The solution name to use as the filename. + output (str): Filename """ - with open(f"{solution}.json", "w", encoding="utf-8") as f: + with open(output, "w", encoding="utf-8") as f: json.dump(final_dict, f, indent=4) - self.logger.info("Data written to %s.json", solution) + self.logger.info("Data written to %s", output) @staticmethod def create_endpoint_dict(endpoint): @@ -216,7 +248,7 @@ def create_endpoint_dict(endpoint): The value dictionary contains "items" and "children" as empty lists and dictionaries, respectively, and "endpoint" as the endpoint's endpoint. - Args: + Parameters: endpoint (dict): The endpoint to create a dictionary for. It should contain "name" and "endpoint" keys. diff --git a/nac_collector/cisco_client_ise.py b/nac_collector/cisco_client_ise.py index 27a4729..ad8f67e 100644 --- a/nac_collector/cisco_client_ise.py +++ b/nac_collector/cisco_client_ise.py @@ -1,14 +1,17 @@ import logging +import click import requests import urllib3 from nac_collector.cisco_client import CiscoClient urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) - logger = logging.getLogger("main") +# Suppress urllib3 warnings +logging.getLogger("urllib3").setLevel(logging.ERROR) + class CiscoClientISE(CiscoClient): """ @@ -98,6 +101,7 @@ def get_from_endpoints(self, endpoints_yaml_file): Returns: dict: The final dictionary containing the data retrieved from the endpoints. """ + # Load endpoints from the YAML file logger.info("Loading endpoints from %s", endpoints_yaml_file) with open(endpoints_yaml_file, "r", encoding="utf-8") as f: @@ -109,108 +113,123 @@ def get_from_endpoints(self, endpoints_yaml_file): # Initialize an empty list for endpoint with children (%v in endpoint['endpoint']) children_endpoints = [] - # Iterate through the endpoints - for endpoint in endpoints: - if all(x not in endpoint["endpoint"] for x in ["%v", "%i"]): - endpoint_dict = CiscoClient.create_endpoint_dict(endpoint) + # Iterate over all endpoints + with click.progressbar(endpoints, label="Processing endpoints") as endpoint_bar: + for endpoint in endpoint_bar: + logger.info("Processing endpoint: %s", endpoint) - response = self.get_request(self.base_url + endpoint["endpoint"]) + if all(x not in endpoint["endpoint"] for x in ["%v", "%i"]): + endpoint_dict = CiscoClient.create_endpoint_dict(endpoint) - # Get the JSON content of the response - data = response.json() - # License API returns a list of dictionaries - if isinstance(data, list): - endpoint_dict[endpoint["name"]].append( - {"data": data, "endpoint": endpoint["endpoint"]} - ) + data = self.fetch_data(endpoint["endpoint"]) - elif data.get("response"): - for i in data.get("response"): + if data is None: endpoint_dict[endpoint["name"]].append( - { - "data": i, - "endpoint": endpoint["endpoint"] - + "/" - + self.get_id_value(i), - } + {"data": {}, "endpoint": endpoint["endpoint"]} ) - # Pagination for ERS API results - elif data.get("SearchResult"): - ers_data = self.process_ers_api_results(data) - for i in ers_data: + + # License API returns a list of dictionaries + elif isinstance(data, list): endpoint_dict[endpoint["name"]].append( - { - "data": i, - "endpoint": endpoint["endpoint"] - + "/" - + self.get_id_value(i), - } + {"data": data, "endpoint": endpoint["endpoint"]} ) - # Check if response is empty list - elif data.get("response") == []: - endpoint_dict[endpoint["name"]].append( - {"data": {}, "endpoint": endpoint["endpoint"]} - ) - - # Save results to dictionary - final_dict.update(endpoint_dict) - - self.log_response(endpoint["endpoint"], response) - - elif "%v" in endpoint["endpoint"]: - children_endpoints.append(endpoint) - - # Iterate through the children endpoints - for endpoint in children_endpoints: - parent_endpoint = endpoint["endpoint"].split("/%v")[0] - - # Iterate over the dictionary - for _, value in final_dict.items(): - index = 0 - # Iterate over the items in final_dict[parent_endpoint] - for item in value: - if parent_endpoint == "/".join( - item.get("endpoint").split("/")[:-1] - ): - # Initialize an empty list for parent endpoint ids - parent_endpoint_ids = [] - - # Add the item's id to the list - try: - parent_endpoint_ids.append(item["data"]["id"]) - except KeyError: - continue - # Iterate over the parent endpoint ids - for id_ in parent_endpoint_ids: - # Replace '%v' in the endpoint with the id - new_endpoint = endpoint["endpoint"].replace("%v", str(id_)) - # Send a GET request to the new endpoint - response = self.get_request(self.base_url + new_endpoint) - # Get the JSON content of the response - data = response.json() - - if data.get("response"): - for i in data.get("response"): - # Check if the key exists - if "children" not in value[index]: - # If the key doesn't exist, create it and initialize it as an empty list - value[index]["children"] = {} - # Check if the key exists - if endpoint["name"] not in value[index]["children"]: - # If the key doesn't exist, create it and initialize it as an empty list - value[index]["children"][endpoint["name"]] = [] - - value[index]["children"][endpoint["name"]].append( - { - "data": i, - "endpoint": new_endpoint - + "/" - + self.get_id_value(i), - } - ) - self.log_response(new_endpoint, response) - - index += 1 + + elif data.get("response"): + for i in data.get("response"): + endpoint_dict[endpoint["name"]].append( + { + "data": i, + "endpoint": endpoint["endpoint"] + + "/" + + self.get_id_value(i), + } + ) + # Pagination for ERS API results + elif data.get("SearchResult"): + + ers_data = self.process_ers_api_results(data) + + for i in ers_data: + endpoint_dict[endpoint["name"]].append( + { + "data": i, + "endpoint": endpoint["endpoint"] + + "/" + + self.get_id_value(i), + } + ) + + # Save results to dictionary + final_dict.update(endpoint_dict) + + elif "%v" in endpoint["endpoint"]: + children_endpoints.append(endpoint) + + # Iterate over all children endpoints + with click.progressbar( + children_endpoints, label="Processing children endpoints" + ) as children_endpoint_bar: + for endpoint in children_endpoint_bar: + logger.info("Processing children endpoint: %s", endpoint) + + parent_endpoint = endpoint["endpoint"].split("/%v")[0] + + # Iterate over the dictionary + for _, value in final_dict.items(): + index = 0 + # Iterate over the items in final_dict[parent_endpoint] + for item in value: + if parent_endpoint == "/".join( + item.get("endpoint").split("/")[:-1] + ): + # Initialize an empty list for parent endpoint ids + parent_endpoint_ids = [] + + # Add the item's id to the list + try: + parent_endpoint_ids.append(item["data"]["id"]) + except KeyError: + continue + # Iterate over the parent endpoint ids + for id_ in parent_endpoint_ids: + # Replace '%v' in the endpoint with the id + new_endpoint = endpoint["endpoint"].replace( + "%v", str(id_) + ) + + data = self.fetch_data(new_endpoint) + + # Get the JSON content of the response + if data is None: + continue + elif data.get("response"): + for i in data.get("response"): + # Check if the key exists + if "children" not in value[index]: + # If the key doesn't exist, create it and initialize it as an empty list + value[index]["children"] = {} + # Check if the key exists + if ( + endpoint["name"] + not in value[index]["children"] + ): + # If the key doesn't exist, create it and initialize it as an empty list + value[index]["children"][ + endpoint["name"] + ] = [] + + value[index]["children"][ + endpoint["name"] + ].append( + { + "data": i, + "endpoint": new_endpoint + + "/" + + self.get_id_value(i), + } + ) + + index += 1 return final_dict def process_ers_api_results(self, data): @@ -252,7 +271,7 @@ def get_id_value(i): """ Attempts to get the 'id' or 'name' value from a dictionary. - Args: + Parameters: i (dict): The dictionary to get the 'id' or 'name' value from. Returns: diff --git a/nac_collector/github_repo_wrapper.py b/nac_collector/github_repo_wrapper.py index 712fac7..98b6182 100644 --- a/nac_collector/github_repo_wrapper.py +++ b/nac_collector/github_repo_wrapper.py @@ -2,6 +2,7 @@ import os import shutil +import click from git import Repo from ruamel.yaml import YAML @@ -84,45 +85,51 @@ def get_definitions(self): self.logger.info("Inspecting YAML files in %s", definitions_dir) endpoints = [] endpoints_dict = [] + for root, _, files in os.walk(definitions_dir): - for file in files: - if file.endswith(".yaml"): - with open(os.path.join(root, file), "r", encoding="utf-8") as f: - data = self.yaml.load(f) - if "rest_endpoint" in data: - self.logger.info( - "Found rest_endpoint: %s in file: %s", - data["rest_endpoint"], - file, - ) - endpoints.append(data["rest_endpoint"]) - # for SDWAN feature_device_templates - if file.split(".yaml")[0] == "feature_device_template": - endpoints_dict.append( - { - "name": file.split(".yaml")[0], - "endpoint": "/template/device/object/%i", - } - ) - else: - endpoints_dict.append( - { - "name": file.split(".yaml")[0], - "endpoint": data["rest_endpoint"], - } + # Iterate over all endpoints + with click.progressbar( + files, label="Processing terraform provider definitions" + ) as files_bar: + for file in files_bar: + + if file.endswith(".yaml"): + with open(os.path.join(root, file), "r", encoding="utf-8") as f: + data = self.yaml.load(f) + if "rest_endpoint" in data: + self.logger.info( + "Found rest_endpoint: %s in file: %s", + data["rest_endpoint"], + file, ) - - # for SDWAN feature_templates - if root.endswith("feature_templates"): - self.logger.debug("Found feature_templates directory") - endpoints.append("/template/feature/object/%i") - endpoints_dict.append( - { - "name": "feature_templates", - "endpoint": "/template/feature/object/%i", - } - ) - break + endpoints.append(data["rest_endpoint"]) + # for SDWAN feature_device_templates + if file.split(".yaml")[0] == "feature_device_template": + endpoints_dict.append( + { + "name": file.split(".yaml")[0], + "endpoint": "/template/device/object/%i", + } + ) + else: + endpoints_dict.append( + { + "name": file.split(".yaml")[0], + "endpoint": data["rest_endpoint"], + } + ) + + # for SDWAN feature_templates + if root.endswith("feature_templates"): + self.logger.debug("Found feature_templates directory") + endpoints.append("/template/feature/object/%i") + endpoints_dict.append( + { + "name": "feature_templates", + "endpoint": "/template/feature/object/%i", + } + ) + break # Save endpoints to a YAML file filename = f"endpoints_{self.solution}.yaml" diff --git a/nac_collector/main.py b/nac_collector/main.py index de7ff91..f6df56f 100644 --- a/nac_collector/main.py +++ b/nac_collector/main.py @@ -59,6 +59,16 @@ is_flag=True, help="Generate endpoint.yaml automatically using provider github repo", ) +@click.option( + "--endpoints-file", + "-e", + type=str, + default=None, + help="Path to the endpoints YAML file", +) +@click.option( + "--output", "-o", type=str, default=None, help="Path to the output json file" +) def cli( solution: str, username: str, @@ -66,6 +76,8 @@ def cli( url: str, verbose: bool, git_provider: bool, + endpoints_file: str, + output: str, ) -> None: """ Command Line Interface (CLI) function for the application. @@ -77,6 +89,9 @@ def cli( url (str): The URL of the server to connect to. verbose (bool): If True, detailed output will be printed to the console. git_provider (bool): If True, the solution will be fetched from a Git provider. + endpoints_file (str): Path to the endpoints YAML file. + output (str): Path to the output json file. + Returns: None @@ -99,6 +114,9 @@ def cli( ) wrapper.get_definitions() + endpoints_yaml_file = endpoints_file or f"endpoints_{solution.lower()}.yaml" + output_file = output or f"{solution.lower()}.json" + if solution == "SDWAN": client = CiscoClientSDWAN( username=username, @@ -115,9 +133,8 @@ def cli( print("Authentication failed. Exiting...") return - endpoints_yaml_file = f"endpoints_{solution.lower()}.yaml" final_dict = client.get_from_endpoints(endpoints_yaml_file) - client.write_to_json(final_dict, f"{solution.lower()}") + client.write_to_json(final_dict, output_file) elif solution == "ISE": client = CiscoClientISE( @@ -135,9 +152,8 @@ def cli( print("Authentication failed. Exiting...") return - endpoints_yaml_file = f"endpoints_{solution.lower()}.yaml" final_dict = client.get_from_endpoints(endpoints_yaml_file) - client.write_to_json(final_dict, f"{solution.lower()}") + client.write_to_json(final_dict, output_file) elif solution == "NDO": client = CiscoClientNDO( @@ -153,9 +169,8 @@ def cli( # Authenticate client.authenticate() - endpoints_yaml_file = f"endpoints_{solution.lower()}.yaml" final_dict = client.get_from_endpoints(endpoints_yaml_file) - client.write_to_json(final_dict, f"{solution.lower()}") + client.write_to_json(final_dict, output_file) else: pass diff --git a/pyproject.toml b/pyproject.toml index 88f6a7b..06f7433 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,4 +47,5 @@ sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" - +[tool.pytest.ini_options] +markers = ["unit", "integration"] diff --git a/tests/ise/integration/__init__.py b/tests/ise/integration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/ise/integration/test_integration.py b/tests/ise/integration/test_integration.py new file mode 100644 index 0000000..7caa9b4 --- /dev/null +++ b/tests/ise/integration/test_integration.py @@ -0,0 +1,112 @@ +import json +from unittest.mock import Mock, patch + +import pytest + +# from nac_collector.main import cli +from nac_collector.cisco_client_ise import CiscoClientISE + +pytestmark = pytest.mark.integration + + +@pytest.fixture +def cisco_client(): + return CiscoClientISE( + username="test_user", + password="test_password", + base_url="https://example.com", + max_retries=3, + retry_after=1, + timeout=5, + ssl_verify=False, + ) + + +def test_cisco_client_ise_with_integration(cisco_client, tmpdir): + + def mock_get_request(url): + # Mock responses for specific API endpoints + mock_responses = { + "https://example.com/api/endpoint_1": { + "response": [ + { + "id": "id_1", + "name": "name_1", + "description": "name_1_description", + } + ] + }, + "https://example.com/api/endpoint_1/id_1/ch_endpoint_1": { + "response": [ + { + "id": "ch_id_1", + "name": "ch_name_1", + "description": "ch_name_1_description", + }, + { + "id": "ch_id_2", + "name": "ch_name_2", + "description": "ch_name_2_description", + }, + ] + }, + "https://example.com/api/endpoint_1/id_1/ch_endpoint_2": { + "response": [ + { + "id": "ch_id_3", + "name": "ch_name_3", + "description": "ch_name_3_description", + }, + { + "id": "ch_id_4", + "name": "ch_name_4", + "description": "ch_name_4_description", + }, + ] + }, + "https://example.com/api/endpoint_2": { + "response": [ + { + "id": "id_2", + "name": "name_2", + "description": "name_2_description", + }, + { + "id": "id_3", + "name": "name_3", + "description": "name_3_description", + }, + { + "id": "id_4", + "name": "name_4", + "description": "name_4_description", + }, + ] + }, + } + + if url in mock_responses: + return Mock(status_code=200, json=lambda: mock_responses[url]) + else: + raise ValueError(f"Unexpected URL in mock_get_request: {url}") + + # Patching get_request method with mock implementation + with patch.object(cisco_client, "get_request", side_effect=mock_get_request): + # Call the method to test + final_dict = cisco_client.get_from_endpoints( + "tests/ise/integration/fixtures/endpoints.yaml" + ) + + # Write final_dict to a temporary JSON file + output_file = tmpdir.join("ise.json") + cisco_client.write_to_json(final_dict, str(output_file)) + + # Compare the content of ise.json with expected data + expected_json_file = "tests/ise/integration/fixtures/ise.json" + with open(expected_json_file, "r") as f_expected, open( + str(output_file), "r" + ) as f_actual: + expected_data = json.load(f_expected) + actual_data = json.load(f_actual) + + assert actual_data == expected_data, "Output JSON data does not match expected" diff --git a/tests/ise/unit/__init__.py b/tests/ise/unit/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/ise/unit/test_authentication.py b/tests/ise/unit/test_authentication.py new file mode 100644 index 0000000..4ab28c9 --- /dev/null +++ b/tests/ise/unit/test_authentication.py @@ -0,0 +1,40 @@ +from unittest.mock import Mock + +import pytest + +from nac_collector.cisco_client_ise import CiscoClientISE + +pytestmark = pytest.mark.unit + + +@pytest.fixture +def cisco_client(): + return CiscoClientISE( + username="test_user", + password="test_password", + base_url="https://example.com", + max_retries=3, + retry_after=1, + timeout=5, + ssl_verify=False, + ) + + +def test_authenticate_success(mocker, cisco_client): + mock_response = Mock() + mock_response.status_code = 200 + mocker.patch("requests.get", return_value=mock_response) + + result = cisco_client.authenticate() + assert result is True + assert cisco_client.session is not None + + +def test_authenticate_failure(mocker, cisco_client): + mock_response = Mock() + mock_response.status_code != 200 + mocker.patch("requests.get", return_value=mock_response) + + result = cisco_client.authenticate() + assert result is False + assert cisco_client.session is None diff --git a/tests/ise/unit/test_ers_api_pagination.py b/tests/ise/unit/test_ers_api_pagination.py new file mode 100644 index 0000000..d89f692 --- /dev/null +++ b/tests/ise/unit/test_ers_api_pagination.py @@ -0,0 +1,132 @@ +from unittest.mock import Mock + +import pytest + +from nac_collector.cisco_client_ise import CiscoClientISE + +pytestmark = pytest.mark.unit + + +@pytest.fixture +def cisco_client(): + return CiscoClientISE( + username="test_user", + password="test_password", + base_url="https://example.com", + max_retries=3, + retry_after=1, + timeout=5, + ssl_verify=False, + ) + + +def test_process_ers_api_results_no_pagination(mocker, cisco_client): + # Mocking response when there's no pagination + mock_response = mocker.Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "SearchResult": { + "resources": [ + {"link": {"href": "https://example.com/api/endpoint/1"}}, + {"link": {"href": "https://example.com/api/endpoint/2"}}, + ], + } + } + mocker.patch.object(cisco_client, "get_request", return_value=mock_response) + + # Call the method to test + data = cisco_client.process_ers_api_results(mock_response.json.return_value) + + # Assertions + assert len(data) == 2 # Total 2 resources without pagination + assert all( + isinstance(item, dict) for item in data + ) # All items should be dictionaries + + +def test_process_ers_api_results_with_pagination(mocker, cisco_client): + # Mocking get_request method + def mock_get_request(url): + mock_responses = { + "https://example.com/api/endpoint/1": { + "key": { + "id": "1", + "name": "name-1", + "attr1": "attr-1", + "attr2": "attr-2", + } + }, + "https://example.com/api/endpoint/2": { + "key": { + "id": "2", + "name": "name-2", + "attr1": "attr-3", + "attr2": "attr-4", + } + }, + "https://example.com/api/endpoint?size=1&page=1": { + "SearchResult": { + "resources": [ + { + "link": {"href": "https://example.com/api/endpoint/1"}, + "id": "1", + "name": "name-1", + "attr1": "attr-1", + "attr2": "attr-2", + } + ], + "nextPage": { + "href": "https://example.com/api/endpoint?size=1&page=2" + }, + } + }, + "https://example.com/api/endpoint?size=1&page=2": { + "SearchResult": { + "resources": [ + { + "link": {"href": "https://example.com/api/endpoint/2"}, + "id": "2", + "name": "name-2", + "attr1": "attr-3", + "attr2": "attr-4", + } + ], + "previousPage": { + "href": "https://example.com/api/endpoint?size=1&page=1" + }, + } + }, + } + + if url in mock_responses: + return Mock(status_code=200, json=lambda: mock_responses[url]) + else: + raise ValueError(f"Unexpected URL in mock_get_request: {url}") + + mocker.patch.object(cisco_client, "get_request", side_effect=mock_get_request) + + # Call the method to test + data = [] + response_url = "https://example.com/api/endpoint?size=1&page=1" + response_data = mock_get_request(response_url).json() + data.extend(cisco_client.process_ers_api_results(response_data)) + + # Assertions + assert len(data) == 2 # Total 2 resources with pagination + for item in data: + assert isinstance(item, dict) + assert "id" in item + assert "name" in item + assert "attr1" in item + assert "attr2" in item + + # Specific assertions for the first and second elements + assert data[0]["id"] == "1" + assert data[0]["name"] == "name-1" + assert data[0]["attr1"] == "attr-1" + assert data[0]["attr2"] == "attr-2" + + assert data[1]["id"] == "2" + assert data[1]["name"] == "name-2" + assert data[1]["attr1"] == "attr-3" + assert data[1]["attr2"] == "attr-4" diff --git a/tests/ise/unit/test_fetch_data.py b/tests/ise/unit/test_fetch_data.py new file mode 100644 index 0000000..738f7a6 --- /dev/null +++ b/tests/ise/unit/test_fetch_data.py @@ -0,0 +1,54 @@ +import pytest + +from nac_collector.cisco_client_ise import CiscoClientISE + +pytestmark = pytest.mark.unit + + +@pytest.fixture +def cisco_client(): + # Mocked CiscoClientISE instance for testing + return CiscoClientISE( + username="test_user", + password="test_password", + base_url="https://example.com", + max_retries=3, + retry_after=1, + timeout=5, + ssl_verify=False, + ) + + +def test_fetch_data_none_response(mocker, cisco_client): + # Mocking fetch_data to return None + mocker.patch.object(cisco_client, "fetch_data", return_value=None) + + endpoint = "/api/test_endpoint" + data = cisco_client.fetch_data(endpoint) + + assert data is None + + +def test_fetch_data_list_response(mocker, cisco_client): + # Mocking fetch_data to return a list of dictionaries + mock_data = [{"id": "1", "name": "Item 1"}, {"id": "2", "name": "Item 2"}] + mocker.patch.object(cisco_client, "fetch_data", return_value=mock_data) + + endpoint = "/api/test_endpoint" + data = cisco_client.fetch_data(endpoint) + + assert isinstance(data, list) + assert len(data) == 2 + assert data[0]["id"] == "1" + + +def test_fetch_data_single_dict_response(mocker, cisco_client): + # Mocking fetch_data to return a single dictionary + mock_data = {"id": "1", "name": "Single Item"} + mocker.patch.object(cisco_client, "fetch_data", return_value=mock_data) + + endpoint = "/api/test_endpoint" + data = cisco_client.fetch_data(endpoint) + + assert isinstance(data, dict) + assert data["name"] == "Single Item" diff --git a/tests/ise/unit/test_id_value_extraction.py b/tests/ise/unit/test_id_value_extraction.py new file mode 100644 index 0000000..0561fbe --- /dev/null +++ b/tests/ise/unit/test_id_value_extraction.py @@ -0,0 +1,29 @@ +import pytest + +from nac_collector.cisco_client_ise import CiscoClientISE + +pytestmark = pytest.mark.unit + + +def test_get_id_value_from_dict_with_id(): + data = {"id": "12345", "name": "Test Item"} + id_value = CiscoClientISE.get_id_value(data) + assert id_value == "12345" + + +def test_get_id_value_from_dict_with_rule_id(): + data = {"rule": {"id": "54321"}, "name": "Test Rule"} + id_value = CiscoClientISE.get_id_value(data) + assert id_value == "54321" + + +def test_get_id_value_from_dict_with_name(): + data = {"name": "Item Name"} + id_value = CiscoClientISE.get_id_value(data) + assert id_value == "Item Name" + + +def test_get_id_value_none(): + data = {} + id_value = CiscoClientISE.get_id_value(data) + assert id_value is None diff --git a/tests/ise/unit/test_initialization.py b/tests/ise/unit/test_initialization.py new file mode 100644 index 0000000..f580c67 --- /dev/null +++ b/tests/ise/unit/test_initialization.py @@ -0,0 +1,24 @@ +import pytest + +from nac_collector.cisco_client_ise import CiscoClientISE + +pytestmark = pytest.mark.unit + + +def test_initialization(): + client = CiscoClientISE( + username="test_user", + password="test_password", + base_url="https://example.com", + max_retries=3, + retry_after=1, + timeout=5, + ssl_verify=False, + ) + assert client.username == "test_user" + assert client.password == "test_password" + assert client.base_url == "https://example.com" + assert client.max_retries == 3 + assert client.retry_after == 1 + assert client.timeout == 5 + assert client.ssl_verify is False