Skip to content

Commit

Permalink
Modified client for ise and added tests (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
kuba-mazurkiewicz authored Jul 11, 2024
1 parent a3d4ea7 commit be4ed75
Show file tree
Hide file tree
Showing 14 changed files with 616 additions and 150 deletions.
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
nac_collector/__pycache__
*.json
*.yaml
tmp/
.envrc
.DS_Store
.pylintrc

# pyenv
.python-version
.python-version
__pycache__/
46 changes: 39 additions & 7 deletions nac_collector/cisco_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ def get_request(self, url):
response = self.session.get(
url, verify=self.ssl_verify, timeout=self.timeout
)

except requests.exceptions.Timeout:
self.logger.error(
"GET %s timed out after %s seconds.", url, self.timeout
Expand Down Expand Up @@ -178,7 +179,7 @@ def log_response(self, endpoint, response):
"""
Logs the response from a GET request.
Args:
Parameters:
endpoint (str): The endpoint the request was sent to.
response (Response): The response from the request.
"""
Expand All @@ -195,17 +196,48 @@ def log_response(self, endpoint, response):
response.status_code,
)

def write_to_json(self, final_dict, solution):
def fetch_data(self, endpoint):
"""
Fetch data from a specified endpoint.
Parameters:
endpoint (str): Endpoint URL.
Returns:
data (dict): The JSON content of the response or None if an error occurred.
"""
# Make the request to the given endpoint
response = self.get_request(self.base_url + endpoint)
if response:
try:
# Get the JSON content of the response
data = response.json()
self.logger.info(
"GET %s succeeded with status code %s",
endpoint,
response.status_code,
)
return data
except ValueError:
self.logger.error(
"Failed to decode JSON from response for endpoint: %s", endpoint
)
return None
else:
self.logger.error("No valid response received for endpoint: %s", endpoint)
return None

def write_to_json(self, final_dict, output):
"""
Writes the final dictionary to a JSON file.
Args:
Parameters:
final_dict (dict): The final dictionary to write to the file.
solution (str): The solution name to use as the filename.
output (str): Filename
"""
with open(f"{solution}.json", "w", encoding="utf-8") as f:
with open(output, "w", encoding="utf-8") as f:
json.dump(final_dict, f, indent=4)
self.logger.info("Data written to %s.json", solution)
self.logger.info("Data written to %s", output)

@staticmethod
def create_endpoint_dict(endpoint):
Expand All @@ -216,7 +248,7 @@ def create_endpoint_dict(endpoint):
The value dictionary contains "items" and "children" as empty lists and dictionaries,
respectively, and "endpoint" as the endpoint's endpoint.
Args:
Parameters:
endpoint (dict): The endpoint to create a dictionary for. It should contain "name"
and "endpoint" keys.
Expand Down
213 changes: 116 additions & 97 deletions nac_collector/cisco_client_ise.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import logging

import click
import requests
import urllib3

from nac_collector.cisco_client import CiscoClient

urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

logger = logging.getLogger("main")

# Suppress urllib3 warnings
logging.getLogger("urllib3").setLevel(logging.ERROR)


class CiscoClientISE(CiscoClient):
"""
Expand Down Expand Up @@ -98,6 +101,7 @@ def get_from_endpoints(self, endpoints_yaml_file):
Returns:
dict: The final dictionary containing the data retrieved from the endpoints.
"""

# Load endpoints from the YAML file
logger.info("Loading endpoints from %s", endpoints_yaml_file)
with open(endpoints_yaml_file, "r", encoding="utf-8") as f:
Expand All @@ -109,108 +113,123 @@ def get_from_endpoints(self, endpoints_yaml_file):
# Initialize an empty list for endpoint with children (%v in endpoint['endpoint'])
children_endpoints = []

# Iterate through the endpoints
for endpoint in endpoints:
if all(x not in endpoint["endpoint"] for x in ["%v", "%i"]):
endpoint_dict = CiscoClient.create_endpoint_dict(endpoint)
# Iterate over all endpoints
with click.progressbar(endpoints, label="Processing endpoints") as endpoint_bar:
for endpoint in endpoint_bar:
logger.info("Processing endpoint: %s", endpoint)

response = self.get_request(self.base_url + endpoint["endpoint"])
if all(x not in endpoint["endpoint"] for x in ["%v", "%i"]):
endpoint_dict = CiscoClient.create_endpoint_dict(endpoint)

# Get the JSON content of the response
data = response.json()
# License API returns a list of dictionaries
if isinstance(data, list):
endpoint_dict[endpoint["name"]].append(
{"data": data, "endpoint": endpoint["endpoint"]}
)
data = self.fetch_data(endpoint["endpoint"])

elif data.get("response"):
for i in data.get("response"):
if data is None:
endpoint_dict[endpoint["name"]].append(
{
"data": i,
"endpoint": endpoint["endpoint"]
+ "/"
+ self.get_id_value(i),
}
{"data": {}, "endpoint": endpoint["endpoint"]}
)
# Pagination for ERS API results
elif data.get("SearchResult"):
ers_data = self.process_ers_api_results(data)
for i in ers_data:

# License API returns a list of dictionaries
elif isinstance(data, list):
endpoint_dict[endpoint["name"]].append(
{
"data": i,
"endpoint": endpoint["endpoint"]
+ "/"
+ self.get_id_value(i),
}
{"data": data, "endpoint": endpoint["endpoint"]}
)
# Check if response is empty list
elif data.get("response") == []:
endpoint_dict[endpoint["name"]].append(
{"data": {}, "endpoint": endpoint["endpoint"]}
)

# Save results to dictionary
final_dict.update(endpoint_dict)

self.log_response(endpoint["endpoint"], response)

elif "%v" in endpoint["endpoint"]:
children_endpoints.append(endpoint)

# Iterate through the children endpoints
for endpoint in children_endpoints:
parent_endpoint = endpoint["endpoint"].split("/%v")[0]

# Iterate over the dictionary
for _, value in final_dict.items():
index = 0
# Iterate over the items in final_dict[parent_endpoint]
for item in value:
if parent_endpoint == "/".join(
item.get("endpoint").split("/")[:-1]
):
# Initialize an empty list for parent endpoint ids
parent_endpoint_ids = []

# Add the item's id to the list
try:
parent_endpoint_ids.append(item["data"]["id"])
except KeyError:
continue
# Iterate over the parent endpoint ids
for id_ in parent_endpoint_ids:
# Replace '%v' in the endpoint with the id
new_endpoint = endpoint["endpoint"].replace("%v", str(id_))
# Send a GET request to the new endpoint
response = self.get_request(self.base_url + new_endpoint)
# Get the JSON content of the response
data = response.json()

if data.get("response"):
for i in data.get("response"):
# Check if the key exists
if "children" not in value[index]:
# If the key doesn't exist, create it and initialize it as an empty list
value[index]["children"] = {}
# Check if the key exists
if endpoint["name"] not in value[index]["children"]:
# If the key doesn't exist, create it and initialize it as an empty list
value[index]["children"][endpoint["name"]] = []

value[index]["children"][endpoint["name"]].append(
{
"data": i,
"endpoint": new_endpoint
+ "/"
+ self.get_id_value(i),
}
)
self.log_response(new_endpoint, response)

index += 1

elif data.get("response"):
for i in data.get("response"):
endpoint_dict[endpoint["name"]].append(
{
"data": i,
"endpoint": endpoint["endpoint"]
+ "/"
+ self.get_id_value(i),
}
)
# Pagination for ERS API results
elif data.get("SearchResult"):

ers_data = self.process_ers_api_results(data)

for i in ers_data:
endpoint_dict[endpoint["name"]].append(
{
"data": i,
"endpoint": endpoint["endpoint"]
+ "/"
+ self.get_id_value(i),
}
)

# Save results to dictionary
final_dict.update(endpoint_dict)

elif "%v" in endpoint["endpoint"]:
children_endpoints.append(endpoint)

# Iterate over all children endpoints
with click.progressbar(
children_endpoints, label="Processing children endpoints"
) as children_endpoint_bar:
for endpoint in children_endpoint_bar:
logger.info("Processing children endpoint: %s", endpoint)

parent_endpoint = endpoint["endpoint"].split("/%v")[0]

# Iterate over the dictionary
for _, value in final_dict.items():
index = 0
# Iterate over the items in final_dict[parent_endpoint]
for item in value:
if parent_endpoint == "/".join(
item.get("endpoint").split("/")[:-1]
):
# Initialize an empty list for parent endpoint ids
parent_endpoint_ids = []

# Add the item's id to the list
try:
parent_endpoint_ids.append(item["data"]["id"])
except KeyError:
continue
# Iterate over the parent endpoint ids
for id_ in parent_endpoint_ids:
# Replace '%v' in the endpoint with the id
new_endpoint = endpoint["endpoint"].replace(
"%v", str(id_)
)

data = self.fetch_data(new_endpoint)

# Get the JSON content of the response
if data is None:
continue
elif data.get("response"):
for i in data.get("response"):
# Check if the key exists
if "children" not in value[index]:
# If the key doesn't exist, create it and initialize it as an empty list
value[index]["children"] = {}
# Check if the key exists
if (
endpoint["name"]
not in value[index]["children"]
):
# If the key doesn't exist, create it and initialize it as an empty list
value[index]["children"][
endpoint["name"]
] = []

value[index]["children"][
endpoint["name"]
].append(
{
"data": i,
"endpoint": new_endpoint
+ "/"
+ self.get_id_value(i),
}
)

index += 1
return final_dict

def process_ers_api_results(self, data):
Expand Down Expand Up @@ -252,7 +271,7 @@ def get_id_value(i):
"""
Attempts to get the 'id' or 'name' value from a dictionary.
Args:
Parameters:
i (dict): The dictionary to get the 'id' or 'name' value from.
Returns:
Expand Down
Loading

0 comments on commit be4ed75

Please sign in to comment.