Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modified client for ise and added tests #33

Merged
merged 1 commit into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
nac_collector/__pycache__
*.json
*.yaml
tmp/
.envrc
.DS_Store
.pylintrc

# pyenv
.python-version
.python-version
__pycache__/
46 changes: 39 additions & 7 deletions nac_collector/cisco_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def get_request(self, url):
response = self.session.get(
url, verify=self.ssl_verify, timeout=self.timeout
)

except requests.exceptions.Timeout:
self.logger.error(
"GET %s timed out after %s seconds.", url, self.timeout
Expand Down Expand Up @@ -178,7 +179,7 @@ def log_response(self, endpoint, response):
"""
Logs the response from a GET request.

Args:
Parameters:
endpoint (str): The endpoint the request was sent to.
response (Response): The response from the request.
"""
Expand All @@ -195,17 +196,48 @@ def log_response(self, endpoint, response):
response.status_code,
)

def write_to_json(self, final_dict, solution):
def fetch_data(self, endpoint):
"""
Fetch data from a specified endpoint.

Parameters:
endpoint (str): Endpoint URL.

Returns:
data (dict): The JSON content of the response or None if an error occurred.
"""
# Make the request to the given endpoint
response = self.get_request(self.base_url + endpoint)
if response:
try:
# Get the JSON content of the response
data = response.json()
self.logger.info(
"GET %s succeeded with status code %s",
endpoint,
response.status_code,
)
return data
except ValueError:
self.logger.error(
"Failed to decode JSON from response for endpoint: %s", endpoint
)
return None
else:
self.logger.error("No valid response received for endpoint: %s", endpoint)
return None

def write_to_json(self, final_dict, output):
"""
Writes the final dictionary to a JSON file.

Args:
Parameters:
final_dict (dict): The final dictionary to write to the file.
solution (str): The solution name to use as the filename.
output (str): Filename
"""
with open(f"{solution}.json", "w", encoding="utf-8") as f:
with open(output, "w", encoding="utf-8") as f:
json.dump(final_dict, f, indent=4)
self.logger.info("Data written to %s.json", solution)
self.logger.info("Data written to %s", output)

@staticmethod
def create_endpoint_dict(endpoint):
Expand All @@ -216,7 +248,7 @@ def create_endpoint_dict(endpoint):
The value dictionary contains "items" and "children" as empty lists and dictionaries,
respectively, and "endpoint" as the endpoint's endpoint.

Args:
Parameters:
endpoint (dict): The endpoint to create a dictionary for. It should contain "name"
and "endpoint" keys.

Expand Down
213 changes: 116 additions & 97 deletions nac_collector/cisco_client_ise.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import logging

import click
import requests
import urllib3

from nac_collector.cisco_client import CiscoClient

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

logger = logging.getLogger("main")

# Suppress urllib3 warnings
logging.getLogger("urllib3").setLevel(logging.ERROR)


class CiscoClientISE(CiscoClient):
"""
Expand Down Expand Up @@ -98,6 +101,7 @@ def get_from_endpoints(self, endpoints_yaml_file):
Returns:
dict: The final dictionary containing the data retrieved from the endpoints.
"""

# Load endpoints from the YAML file
logger.info("Loading endpoints from %s", endpoints_yaml_file)
with open(endpoints_yaml_file, "r", encoding="utf-8") as f:
Expand All @@ -109,108 +113,123 @@ def get_from_endpoints(self, endpoints_yaml_file):
# Initialize an empty list for endpoint with children (%v in endpoint['endpoint'])
children_endpoints = []

# Iterate through the endpoints
for endpoint in endpoints:
if all(x not in endpoint["endpoint"] for x in ["%v", "%i"]):
endpoint_dict = CiscoClient.create_endpoint_dict(endpoint)
# Iterate over all endpoints
with click.progressbar(endpoints, label="Processing endpoints") as endpoint_bar:
for endpoint in endpoint_bar:
logger.info("Processing endpoint: %s", endpoint)

response = self.get_request(self.base_url + endpoint["endpoint"])
if all(x not in endpoint["endpoint"] for x in ["%v", "%i"]):
endpoint_dict = CiscoClient.create_endpoint_dict(endpoint)

# Get the JSON content of the response
data = response.json()
# License API returns a list of dictionaries
if isinstance(data, list):
endpoint_dict[endpoint["name"]].append(
{"data": data, "endpoint": endpoint["endpoint"]}
)
data = self.fetch_data(endpoint["endpoint"])

elif data.get("response"):
for i in data.get("response"):
if data is None:
endpoint_dict[endpoint["name"]].append(
{
"data": i,
"endpoint": endpoint["endpoint"]
+ "/"
+ self.get_id_value(i),
}
{"data": {}, "endpoint": endpoint["endpoint"]}
)
# Pagination for ERS API results
elif data.get("SearchResult"):
ers_data = self.process_ers_api_results(data)
for i in ers_data:

# License API returns a list of dictionaries
elif isinstance(data, list):
endpoint_dict[endpoint["name"]].append(
{
"data": i,
"endpoint": endpoint["endpoint"]
+ "/"
+ self.get_id_value(i),
}
{"data": data, "endpoint": endpoint["endpoint"]}
)
# Check if response is empty list
elif data.get("response") == []:
endpoint_dict[endpoint["name"]].append(
{"data": {}, "endpoint": endpoint["endpoint"]}
)

# Save results to dictionary
final_dict.update(endpoint_dict)

self.log_response(endpoint["endpoint"], response)

elif "%v" in endpoint["endpoint"]:
children_endpoints.append(endpoint)

# Iterate through the children endpoints
for endpoint in children_endpoints:
parent_endpoint = endpoint["endpoint"].split("/%v")[0]

# Iterate over the dictionary
for _, value in final_dict.items():
index = 0
# Iterate over the items in final_dict[parent_endpoint]
for item in value:
if parent_endpoint == "/".join(
item.get("endpoint").split("/")[:-1]
):
# Initialize an empty list for parent endpoint ids
parent_endpoint_ids = []

# Add the item's id to the list
try:
parent_endpoint_ids.append(item["data"]["id"])
except KeyError:
continue
# Iterate over the parent endpoint ids
for id_ in parent_endpoint_ids:
# Replace '%v' in the endpoint with the id
new_endpoint = endpoint["endpoint"].replace("%v", str(id_))
# Send a GET request to the new endpoint
response = self.get_request(self.base_url + new_endpoint)
# Get the JSON content of the response
data = response.json()

if data.get("response"):
for i in data.get("response"):
# Check if the key exists
if "children" not in value[index]:
# If the key doesn't exist, create it and initialize it as an empty list
value[index]["children"] = {}
# Check if the key exists
if endpoint["name"] not in value[index]["children"]:
# If the key doesn't exist, create it and initialize it as an empty list
value[index]["children"][endpoint["name"]] = []

value[index]["children"][endpoint["name"]].append(
{
"data": i,
"endpoint": new_endpoint
+ "/"
+ self.get_id_value(i),
}
)
self.log_response(new_endpoint, response)

index += 1

elif data.get("response"):
for i in data.get("response"):
endpoint_dict[endpoint["name"]].append(
{
"data": i,
"endpoint": endpoint["endpoint"]
+ "/"
+ self.get_id_value(i),
}
)
# Pagination for ERS API results
elif data.get("SearchResult"):

ers_data = self.process_ers_api_results(data)

for i in ers_data:
endpoint_dict[endpoint["name"]].append(
{
"data": i,
"endpoint": endpoint["endpoint"]
+ "/"
+ self.get_id_value(i),
}
)

# Save results to dictionary
final_dict.update(endpoint_dict)

elif "%v" in endpoint["endpoint"]:
children_endpoints.append(endpoint)

# Iterate over all children endpoints
with click.progressbar(
children_endpoints, label="Processing children endpoints"
) as children_endpoint_bar:
for endpoint in children_endpoint_bar:
logger.info("Processing children endpoint: %s", endpoint)

parent_endpoint = endpoint["endpoint"].split("/%v")[0]

# Iterate over the dictionary
for _, value in final_dict.items():
index = 0
# Iterate over the items in final_dict[parent_endpoint]
for item in value:
if parent_endpoint == "/".join(
item.get("endpoint").split("/")[:-1]
):
# Initialize an empty list for parent endpoint ids
parent_endpoint_ids = []

# Add the item's id to the list
try:
parent_endpoint_ids.append(item["data"]["id"])
except KeyError:
continue
# Iterate over the parent endpoint ids
for id_ in parent_endpoint_ids:
# Replace '%v' in the endpoint with the id
new_endpoint = endpoint["endpoint"].replace(
"%v", str(id_)
)

data = self.fetch_data(new_endpoint)

# Get the JSON content of the response
if data is None:
continue
elif data.get("response"):
for i in data.get("response"):
# Check if the key exists
if "children" not in value[index]:
# If the key doesn't exist, create it and initialize it as an empty list
value[index]["children"] = {}
# Check if the key exists
if (
endpoint["name"]
not in value[index]["children"]
):
# If the key doesn't exist, create it and initialize it as an empty list
value[index]["children"][
endpoint["name"]
] = []

value[index]["children"][
endpoint["name"]
].append(
{
"data": i,
"endpoint": new_endpoint
+ "/"
+ self.get_id_value(i),
}
)

index += 1
return final_dict

def process_ers_api_results(self, data):
Expand Down Expand Up @@ -252,7 +271,7 @@ def get_id_value(i):
"""
Attempts to get the 'id' or 'name' value from a dictionary.

Args:
Parameters:
i (dict): The dictionary to get the 'id' or 'name' value from.

Returns:
Expand Down
Loading