Commit

reformat
cadosecurity committed Dec 4, 2023
1 parent 2c8be99 commit a8aebf5
Showing 5 changed files with 109 additions and 70 deletions.
47 changes: 14 additions & 33 deletions cloudgrep/__main__.py
@@ -7,7 +7,8 @@
 
 # Define a custom argument type for a list of strings
 def list_of_strings(arg):
-    return arg.split(',')
+    return arg.split(",")
+
 
 def main() -> None:
     parser = argparse.ArgumentParser(
@@ -18,10 +19,7 @@ def main() -> None:
     parser.add_argument("-cn", "--container-name", help="Azure Container Name to Search", required=False)
     parser.add_argument("-gb", "--google-bucket", help="Google Cloud Bucket to Search", required=False)
     parser.add_argument(
-        "-q",
-        "--query",
-        help="Text to search for. Will be parsed as a Regex. E.g. example.com",
-        required=False
+        "-q", "--query", help="Text to search for. Will be parsed as a Regex. E.g. example.com", required=False
     )
     parser.add_argument(
         "-v",
@@ -43,10 +41,7 @@ def main() -> None:
         default="",
     )
     parser.add_argument(
-        "-f",
-        "--filename",
-        help="Optionally filter on Objects that match a keyword. E.g. .log.gz ",
-        required=False
+        "-f", "--filename", help="Optionally filter on Objects that match a keyword. E.g. .log.gz ", required=False
     )
     parser.add_argument(
         "-s",
@@ -73,46 +68,32 @@ def main() -> None:
         help="Set an AWS profile to use. E.g. default, dev, prod.",
         required=False,
     )
-    parser.add_argument(
-        "-d",
-        "--debug",
-        help="Enable Debug logging. ",
-        action="store_true",
-        required=False
-    )
+    parser.add_argument("-d", "--debug", help="Enable Debug logging. ", action="store_true", required=False)
     parser.add_argument(
-        "-hf",
-        "--hide_filenames",
-        help="Dont show matching filenames. ",
-        action="store_true",
-        required=False
+        "-hf", "--hide_filenames", help="Dont show matching filenames. ", action="store_true", required=False
     )
     parser.add_argument(
         "-lt",
         "--log_type",
         help="Return individual matching log entries based on pre-defined log types, otherwise custom log_format and log_properties can be used. E.g. cloudtrail. ",
-        required=False
+        required=False,
     )
     parser.add_argument(
         "-lf",
         "--log_format",
         help="Define custom log format of raw file to parse before applying search logic. Used if --log_type is not defined. E.g. json. ",
-        required=False
+        required=False,
     )
     parser.add_argument(
         "-lp",
        "--log_properties",
         type=list_of_strings,
-        help="Define custom list of properties to traverse to dynamically extract final list of log records. Used if --log_type is not defined. E.g. [""Records""]. ",
-        required=False
-    )
-    parser.add_argument(
-        "-jo",
-        "--json_output",
-        help="Output as JSON.",
-        required=False,
-        default=False
+        help="Define custom list of properties to traverse to dynamically extract final list of log records. Used if --log_type is not defined. E.g. ["
+        "Records"
+        "]. ",
+        required=False,
     )
+    parser.add_argument("-jo", "--json_output", help="Output as JSON.", required=False, default=False)
     args = vars(parser.parse_args())
 
     if len(sys.argv) == 1:
@@ -143,7 +124,7 @@ def main() -> None:
         args["log_format"],
         args["log_properties"],
         args["profile"],
-        args["json_output"]
+        args["json_output"],
     )
 
 
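Context for the hunks above: list_of_strings is wired into argparse through the type= hook on --log_properties, so a comma-separated value arrives in main() already split into a Python list. A minimal standalone sketch of that behaviour (not part of the commit; the sample value "Records,Event" is invented):

import argparse


def list_of_strings(arg):
    return arg.split(",")


parser = argparse.ArgumentParser()
parser.add_argument("-lp", "--log_properties", type=list_of_strings)
args = parser.parse_args(["-lp", "Records,Event"])
print(args.log_properties)  # ['Records', 'Event']

One side effect worth noting in the reformatted help string: "E.g. [" "Records" "]. " is three adjacent string literals, which Python concatenates at compile time, so --help renders it as "E.g. [Records]. " (the same concatenation the original [""Records""] spelling produced).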
12 changes: 9 additions & 3 deletions cloudgrep/cloud.py
@@ -38,7 +38,9 @@ def download_file(key: str) -> None:
             with tempfile.NamedTemporaryFile() as tmp:
                 logging.info(f"Downloading {bucket} {key} to {tmp.name}")
                 s3.download_file(bucket, key, tmp.name)
-                matched = Search().search_file(tmp.name, key, query, hide_filenames, yara_rules, log_format, log_properties, json_output)
+                matched = Search().search_file(
+                    tmp.name, key, query, hide_filenames, yara_rules, log_format, log_properties, json_output
+                )
                 if matched:
                     nonlocal matched_count
                     matched_count += 1
@@ -82,7 +84,9 @@ def download_file(key: str) -> None:
                 with open(tmp.name, "wb") as my_blob:
                     blob_data = blob_client.download_blob()
                     blob_data.readinto(my_blob)
-                matched = Search().search_file(tmp.name, key, query, hide_filenames, yara_rules, log_format, log_properties, json_output)
+                matched = Search().search_file(
+                    tmp.name, key, query, hide_filenames, yara_rules, log_format, log_properties, json_output
+                )
                 if matched:
                     nonlocal matched_count
                     matched_count += 1
@@ -118,7 +122,9 @@ def download_file(key: str) -> None:
                 logging.info(f"Downloading {bucket} {key} to {tmp.name}")
                 blob = bucket_gcp.get_blob(key)
                 blob.download_to_filename(tmp.name)
-                matched = Search().search_file(tmp.name, key, query, hide_filenames, yara_rules, log_format, log_properties, json_output)
+                matched = Search().search_file(
+                    tmp.name, key, query, hide_filenames, yara_rules, log_format, log_properties, json_output
+                )
                 if matched:
                     nonlocal matched_count
                     matched_count += 1
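All three provider paths above share one pattern: download the remote object into a NamedTemporaryFile, then run the search against the local path while the file still exists. A stripped-down sketch of the S3 variant (standalone and illustrative; download_and_search and contains_query are hypothetical names, and running it needs real AWS credentials plus an existing bucket and key):

import tempfile

import boto3


def download_and_search(bucket: str, key: str, query: bytes) -> bool:
    s3 = boto3.client("s3")
    # The temp file lives only for the duration of the with-block,
    # so the search must happen before it closes.
    with tempfile.NamedTemporaryFile() as tmp:
        s3.download_file(bucket, key, tmp.name)
        return contains_query(tmp.name, query)


def contains_query(path: str, query: bytes) -> bool:
    # Stand-in for the Search().search_file(...) call in the real code.
    with open(path, "rb") as f:
        return query in f.read()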
35 changes: 28 additions & 7 deletions cloudgrep/cloudgrep.py
@@ -4,7 +4,8 @@
 from typing import Optional
 import logging
 from cloudgrep.cloud import Cloud
-#import yara # type: ignore
+
+# import yara # type: ignore
 
 
 class CloudGrep:
@@ -46,7 +47,9 @@ def search(
                 log_format = "json"
                 log_properties = ["Records"]
             case _:
-                logging.error(f"Invalid log_type value ('{log_type}') unhandled in switch statement in 'search' function.")
+                logging.error(
+                    f"Invalid log_type value ('{log_type}') unhandled in switch statement in 'search' function."
+                )
 
         if yara_file:
             logging.log(f"Loading yara rules from {yara_file}")
@@ -73,13 +76,19 @@ def search(
             s3_client = boto3.client("s3")
             region = s3_client.get_bucket_location(Bucket=bucket)
             if log_format != None:
-                logging.warning(f"Bucket is in region: {region['LocationConstraint']} : Search from the same region to avoid egress charges.")
+                logging.warning(
+                    f"Bucket is in region: {region['LocationConstraint']} : Search from the same region to avoid egress charges."
+                )
                 logging.warning(f"Searching {len(matching_keys)} files in {bucket} for {query}...")
 
             else:
-                print(f"Bucket is in region: {region['LocationConstraint']} : Search from the same region to avoid egress charges.")
+                print(
+                    f"Bucket is in region: {region['LocationConstraint']} : Search from the same region to avoid egress charges."
+                )
                 print(f"Searching {len(matching_keys)} files in {bucket} for {query}...")
-            Cloud().download_from_s3_multithread(bucket, matching_keys, query, hide_filenames, yara_rules, log_format, log_properties, json_output)
+            Cloud().download_from_s3_multithread(
+                bucket, matching_keys, query, hide_filenames, yara_rules, log_format, log_properties, json_output
+            )
 
         if account_name and container_name:
             matching_keys = list(
@@ -88,7 +97,17 @@ def search(
                 )
             )
             print(f"Searching {len(matching_keys)} files in {account_name}/{container_name} for {query}...")
-            Cloud().download_from_azure(account_name, container_name, matching_keys, query, hide_filenames, yara_rules, log_format, log_properties, json_output)
+            Cloud().download_from_azure(
+                account_name,
+                container_name,
+                matching_keys,
+                query,
+                hide_filenames,
+                yara_rules,
+                log_format,
+                log_properties,
+                json_output,
+            )
 
         if google_bucket:
             matching_keys = list(
@@ -97,4 +116,6 @@ def search(
 
             print(f"Searching {len(matching_keys)} files in {google_bucket} for {query}...")
 
-            Cloud().download_from_google(google_bucket, matching_keys, query, hide_filenames, yara_rules, log_format, log_properties, json_output)
+            Cloud().download_from_google(
+                google_bucket, matching_keys, query, hide_filenames, yara_rules, log_format, log_properties, json_output
+            )
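A detail behind the egress warning reformatted above: get_bucket_location returns a dict whose LocationConstraint value is a region string, except that buckets created in us-east-1 come back as None, so the message can legitimately print "Bucket is in region: None". A small sketch of that call (the bucket name is hypothetical and AWS credentials are required):

import boto3

s3_client = boto3.client("s3")
region = s3_client.get_bucket_location(Bucket="my-example-bucket")
# 'LocationConstraint' is a string such as "eu-west-2",
# or None for buckets created in us-east-1.
print(f"Bucket is in region: {region['LocationConstraint']}")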
45 changes: 25 additions & 20 deletions cloudgrep/search.py
@@ -8,6 +8,7 @@
 import json
 import csv
 
+
 class Search:
     def get_all_strings_line(self, file_path: str) -> List[str]:
         """Get all the strings from a file line by line
@@ -32,10 +33,10 @@ def print_match(self, matched_line_dict: dict, hide_filenames: bool, json_output
         else:
             line = ""
             if "line" in matched_line_dict:
-                line = matched_line_dict['line']
+                line = matched_line_dict["line"]
             if "match_rule" in matched_line_dict:
                 line = f"{matched_line_dict['match_rule']}: {matched_line_dict['match_strings']}"
-
+
             if not hide_filenames:
                 print(f"{matched_line_dict['key_name']}: {line}")
             else:
@@ -63,7 +64,9 @@ def search_logs(
             case "csv":
                 line_parsed = csv.DictReader(line)
             case _:
-                logging.error(f"Invalid log_format value ('{log_format}') in switch statement in 'search_logs' function, so defaulting to 'json'.")
+                logging.error(
+                    f"Invalid log_format value ('{log_format}') in switch statement in 'search_logs' function, so defaulting to 'json'."
+                )
                 # Default to JSON format.
                 log_format = "json"
                 line_parsed = json.loads(line)
@@ -81,12 +84,9 @@ def search_logs(
         # Perform per-line searching.
         for record in line_parsed:
             if re.search(search, json.dumps(record)):
-                matched_line_dict = {
-                    "key_name": key_name,
-                    "line" : record
-                }
+                matched_line_dict = {"key_name": key_name, "line": record}
                 self.print_match(matched_line_dict, hide_filenames, json_output)
-
+
     def search_line(
         self,
         key_name: str,
@@ -102,10 +102,7 @@ def search_line(
             if log_format != None:
                 self.search_logs(line, key_name, search, hide_filenames, log_format, log_properties, json_output)
             else:
-                matched_line_dict = {
-                    "key_name": key_name,
-                    "line" : line
-                }
+                matched_line_dict = {"key_name": key_name, "line": line}
                 self.print_match(matched_line_dict, hide_filenames, json_output)
             return True
         return False
@@ -115,11 +112,7 @@ def yara_scan_file(self, file_name: str, key_name: str, hide_filenames: bool, ya
         matches = yara_rules.match(file_name)
         if matches:
             for match in matches:
-                matched_line_dict = {
-                    "key_name": key_name,
-                    "match_rule": match.rule,
-                    "match_strings": match.strings
-                }
+                matched_line_dict = {"key_name": key_name, "match_rule": match.rule, "match_strings": match.strings}
                 self.print_match(matched_line_dict, hide_filenames, json_output)
             matched = True
         return matched
@@ -145,7 +138,9 @@ def search_file(
         if key_name.endswith(".gz"):
             with gzip.open(file_name, "rt") as f:
                 for line in f:
-                    if self.search_line(key_name, search, hide_filenames, line, log_format, log_properties, json_output):
+                    if self.search_line(
+                        key_name, search, hide_filenames, line, log_format, log_properties, json_output
+                    ):
                         matched = True
         elif key_name.endswith(".zip"):
             with tempfile.TemporaryDirectory() as tempdir:
@@ -157,11 +152,21 @@ def search_file(
                 if os.path.isfile(os.path.join(tempdir, filename)):
                     with open(os.path.join(tempdir, filename)) as f:
                         for line in f:
-                            if self.search_line("{key_name}/{filename}", search, hide_filenames, line, log_format, log_properties, json_output):
+                            if self.search_line(
+                                "{key_name}/{filename}",
+                                search,
+                                hide_filenames,
+                                line,
+                                log_format,
+                                log_properties,
+                                json_output,
+                            ):
                                 matched = True
         else:
             for line in self.get_all_strings_line(file_name):
-                if self.search_line(key_name, search, hide_filenames, line, log_format, log_properties, json_output):
+                if self.search_line(
+                    key_name, search, hide_filenames, line, log_format, log_properties, json_output
+                ):
                     matched = True
 
         return matched
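The search_logs hunks above show the core matching step: each extracted record is serialised back to JSON and the query regex runs over that string, which lets a single pattern match keys and values alike. A self-contained sketch of that loop (the sample records and query are invented for illustration):

import json
import re

# Stands in for the parsed, traversed log line that search_logs iterates.
line_parsed = json.loads('[{"eventName": "GetObject"}, {"eventName": "PutObject"}]')
search = "GetObject"

for record in line_parsed:
    if re.search(search, json.dumps(record)):
        # Mirrors the matched_line_dict handed to print_match.
        print({"key_name": "example.json", "line": record})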
40 changes: 33 additions & 7 deletions tests/test_unit.py
@@ -137,15 +137,16 @@ def test_yara(self) -> None:
         self.assertTrue(matched)
         self.assertEqual(output, "{'match_rule': 'rule_name', 'match_strings': [$a]}")
-
-
+
     # Unit test to check that all output is json parseable
     def test_json_output(self) -> None:
         # Arrange
         search = Search()
 
         # Act
         with patch("sys.stdout", new=StringIO()) as fake_out:
-            found = Search().search_file(f"{BASE_PATH}/data/000000.gz", "000000.gz", "Running on machine", False, None, None, None, True)
+            found = Search().search_file(
+                f"{BASE_PATH}/data/000000.gz", "000000.gz", "Running on machine", False, None, None, None, True
+            )
             output = fake_out.getvalue().strip()
 
         # Assert we can parse the output
@@ -162,13 +163,38 @@ def test_search_cloudtrail(self) -> None:
 
         # Act
         # Test it doesnt crash on bad json
-        found = Search().search_file(f"{BASE_PATH}/data/bad_cloudtrail.json", "bad_cloudtrail.json", "Running on machine", False, None, log_format, log_properties)
-        found = Search().search_file(f"{BASE_PATH}/data/cloudtrail.json", "cloudtrail.json", "Running on machine", False, None, log_format, log_properties)
+        found = Search().search_file(
+            f"{BASE_PATH}/data/bad_cloudtrail.json",
+            "bad_cloudtrail.json",
+            "Running on machine",
+            False,
+            None,
+            log_format,
+            log_properties,
+        )
+        found = Search().search_file(
+            f"{BASE_PATH}/data/cloudtrail.json",
+            "cloudtrail.json",
+            "Running on machine",
+            False,
+            None,
+            log_format,
+            log_properties,
+        )
         # Get the output for a hit
         with patch("sys.stdout", new=StringIO()) as fake_out:
-            found = Search().search_file(f"{BASE_PATH}/data/cloudtrail_singleline.json", "cloudtrail_singleline.json", "SignatureVersion", False, None, log_format, log_properties, True)
+            found = Search().search_file(
+                f"{BASE_PATH}/data/cloudtrail_singleline.json",
+                "cloudtrail_singleline.json",
+                "SignatureVersion",
+                False,
+                None,
+                log_format,
+                log_properties,
+                True,
+            )
             output = fake_out.getvalue().strip()
 
         # Assert we can parse the output
         self.assertIn("SignatureVersion", output)
-        self.assertTrue(json.loads(output))
\ No newline at end of file
+        self.assertTrue(json.loads(output))
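The closing assertions in test_search_cloudtrail pin down the contract the JSON-output mode relies on: every line printed for a hit must itself parse as JSON. The same check can be reproduced in isolation (the output string below is an invented example, not captured from the tests):

import json

output = '{"key_name": "cloudtrail_singleline.json", "line": {"SignatureVersion": "1"}}'
assert "SignatureVersion" in output
assert json.loads(output)  # raises json.JSONDecodeError (a ValueError) if the line is not valid JSON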
