Fix unit test
cadosecurity committed Jan 29, 2024
1 parent 104c4ac commit 6ee741d
Showing 5 changed files with 109 additions and 99 deletions.
2 changes: 1 addition & 1 deletion cloudgrep/__main__.py
@@ -25,7 +25,7 @@ def main() -> None:
         "--query",
         type=list_of_strings,
         help="Text to search for. Will be parsed as a Regex. E.g. example.com",
-        required=False
+        required=False,
     )
     parser.add_argument(
         "-v",
17 changes: 12 additions & 5 deletions cloudgrep/cloud.py
@@ -66,7 +66,6 @@ def download_from_azure(
         log_format: Optional[str] = None,
         log_properties: List[str] = [],
         json_output: Optional[bool] = False,
-
     ) -> int:
         """Download every file in the container from azure
         Returns number of matched files"""
@@ -84,11 +83,19 @@ def download_file(key: str) -> None:
             try:
                 blob_client = container_client.get_blob_client(key)
                 with open(tmp.name, "wb") as my_blob:
-
                     blob_data = blob_client.download_blob()
-                    blob_data.readinto(my_blob)
+                    blob_data.readinto(my_blob)
                 matched = Search().search_file(
-                    tmp.name, key, query, hide_filenames, yara_rules,log_format, log_properties, json_output, account_name
+                    tmp.name,
+                    key,
+                    query,
+                    hide_filenames,
+                    yara_rules,
+                    log_format,
+                    log_properties,
+                    json_output,
+                    account_name,
                 )
                 if matched:
                     nonlocal matched_count
@@ -231,7 +238,7 @@ def get_azure_objects(
         blobs = container_client.list_blobs(name_starts_with=prefix)

         for blob in blobs:
-
             if self.filter_object_azure(
                 blob,
                 key_contains,
7 changes: 2 additions & 5 deletions cloudgrep/cloudgrep.py
@@ -13,7 +13,7 @@ class CloudGrep:
     def load_queries(self, file: str) -> List[str]:
         """Load in a list of queries from a file"""
         with open(file, "r") as f:
-            return ([line.strip() for line in f.readlines() if len(line.strip())])
+            return [line.strip() for line in f.readlines() if len(line.strip())]

     def search(
         self,
@@ -100,10 +100,7 @@ def search(
                 account_name, container_name, prefix, key_contains, parsed_from_date, parsed_end_date, file_size
             )
         )
-        if log_format != None:
-            logging.warning(f"Searching {len(matching_keys)} files in {account_name}/{container_name} for {query}...")
-        else:
-            print(f"Searching {len(matching_keys)} files in {account_name}/{container_name} for {query}...")
+        print(f"Searching {len(matching_keys)} files in {account_name}/{container_name} for {query}...")

         Cloud().download_from_azure(
             account_name,
113 changes: 59 additions & 54 deletions cloudgrep/search.py
@@ -14,7 +14,6 @@ def get_all_strings_line(self, file_path: str) -> List[str]:
         """Get all the strings from a file line by line
         We do this instead of f.readlines() as this supports binary files too
         """
-        print("getting the strings line with bytes")
         with open(file_path, "rb") as f:
             read_bytes = f.read()
             b = read_bytes.decode("utf-8", "ignore")
@@ -25,14 +24,13 @@ def print_match(self, matched_line_dict: dict, hide_filenames: bool, json_output: Optional[bool]) -> None:
         """Print matched line"""
         if json_output:
-            # print("jsoni")
             if hide_filenames:
                 matched_line_dict.pop("key_name")
             try:
-
                 print(json.dumps(matched_line_dict))
             except TypeError:
-
                 print(str(matched_line_dict))
         else:
             line = ""
@@ -42,10 +40,9 @@ def print_match(self, matched_line_dict: dict, hide_filenames: bool, json_output: Optional[bool]) -> None:
                 line = f"{matched_line_dict['match_rule']}: {matched_line_dict['match_strings']}"

             if not hide_filenames:
-
-                print(f"{matched_line_dict['key_name']}: {matched_line_dict}")
+                print(f"{matched_line_dict['key_name']}: {line}")
             else:
-
                 print(line)

     def search_logs(
@@ -79,7 +76,7 @@ def search_logs(

         # Step into property/properties to get to final list of lines for per-line searching.
         if log_properties != None:
-
             for log_property in log_properties:
                 if line_parsed:
                     line_parsed = line_parsed.get(log_property, None)
@@ -91,7 +88,7 @@ def search_logs(
             # Perform per-line searching.
             for record in line_parsed:
                 if re.search(search, json.dumps(record)):
-
                     matched_line_dict = {"key_name": key_name, "query": search, "line": record}
                     self.print_match(matched_line_dict, hide_filenames, json_output)

@@ -109,11 +106,13 @@ def search_line(
         matched = False
         for cur_search in search:
             if re.search(cur_search, line):
-
                 if log_format != None:
-                    self.search_logs(line, key_name, cur_search, hide_filenames, log_format, log_properties, json_output)
+                    self.search_logs(
+                        line, key_name, cur_search, hide_filenames, log_format, log_properties, json_output
+                    )
                 else:
-
                     matched_line_dict = {"key_name": key_name, "query": cur_search, "line": line}
                     self.print_match(matched_line_dict, hide_filenames, json_output)
                 matched = True
@@ -140,35 +139,32 @@ def search_file(
         log_properties: List[str] = [],
         json_output: Optional[bool] = False,
         account_name: Optional[str] = None,
-
-
     ) -> bool:
         """Regex search of the file line by line"""
         matched = False
-        # print("search")
-
         logging.info(f"Searching {file_name} for {search}")
         if yara_rules:
-            print("yara")
             matched = self.yara_scan_file(file_name, key_name, hide_filenames, yara_rules, json_output)
         else:
-            if key_name.endswith(".gz"):
-
+            if key_name.endswith(".gz"):
                 with gzip.open(file_name, "rt") as f:
-
                     if account_name:
-                        json_data = json.load(f)
-                        for i in range(len(json_data)):
-                            data = json_data[i]
-                            line = json.dumps(data)
-
-                            if self.search_line(
-                                key_name, search, hide_filenames, line, log_format, log_properties, json_output
-                            ):
-                                matched = True
+                        try:
+                            # Try to load the file as JSON
+                            json_data = json.load(f)
+                            for i in range(len(json_data)):
+                                data = json_data[i]
+                                line = json.dumps(data)
+                                if self.search_line(
+                                    key_name, search, hide_filenames, line, log_format, log_properties, json_output
+                                ):
+                                    matched = True
+                        except json.JSONDecodeError:
+                            logging.info(f"File {file_name} is not JSON")
                     else:
-
                         for line in f:
-
                             if self.search_line(
                                 key_name, search, hide_filenames, line, log_format, log_properties, json_output
                             ):
@@ -183,35 +179,44 @@ def search_file(
             if os.path.isfile(os.path.join(tempdir, filename)):
                 with open(os.path.join(tempdir, filename)) as f:
-                    if account_name:
-                        if account_name:
-                            json_data = json.load(f)
-                            for i in range(len(json_data)):
-                                data = json_data[i]
-                                line = json.dumps(data)
-
-                                if self.search_line(
-                                    key_name, search, hide_filenames, line, log_format, log_properties, json_output
-                                ):
-                                    matched = True
-                    for line in f:
-                        if self.search_line(
-                            f"{key_name}/{filename}",
-                            search,
-                            hide_filenames,
-                            line,
-                            log_format,
-                            log_properties,
-                            json_output,
-                        ):
-                            matched = True
+                    if account_name:
+                        try:
+                            json_data = json.load(f)
+                            for i in range(len(json_data)):
+                                data = json_data[i]
+                                line = json.dumps(data)
+
+                                if self.search_line(
+                                    key_name,
+                                    search,
+                                    hide_filenames,
+                                    line,
+                                    log_format,
+                                    log_properties,
+                                    json_output,
+                                ):
+                                    matched = True
+                        except json.JSONDecodeError:
+                            logging.info(f"File {file_name} is not JSON")
+                    else:
+                        for line in f:
+                            if self.search_line(
+                                f"{key_name}/{filename}",
+                                search,
+                                hide_filenames,
+                                line,
+                                log_format,
+                                log_properties,
+                                json_output,
+                            ):
+                                matched = True
         else:
-
             for line in self.get_all_strings_line(file_name):
-
                 if self.search_line(
                     key_name, search, hide_filenames, line, log_format, log_properties, json_output
                 ):
                     matched = True

-            return matched
+        return matched
69 changes: 35 additions & 34 deletions tests/test_unit.py
@@ -75,7 +75,7 @@ def test_e2e(self) -> None:
         assert len(matching_keys) == 3

         print(f"Checking we only get one search hit in: {matching_keys}")
-        hits = Cloud().download_from_s3_multithread(_BUCKET, matching_keys, _QUERY, False, None) # type: ignore
+        hits = Cloud().download_from_s3_multithread(_BUCKET, matching_keys, _QUERY, False, None)  # type: ignore
         assert hits == 3

         print("Testing with multiple queries from a file")
@@ -118,14 +118,14 @@ def test_returns_true_if_all_conditions_are_met(self) -> None:

         self.assertTrue(result)

-    # # returns a string with the contents of the file
+    # returns a string with the contents of the file
     def test_returns_string_with_file_contents(self) -> None:
         file = "queries.txt"
         with open(file, "w") as f:
             f.write("query1\nquery2\nquery3")
         queries = CloudGrep().load_queries(file)
         self.assertIsInstance(queries, List)
-        self.assertEqual(queries, ["query1", "query2", "query3"] )
+        self.assertEqual(queries, ["query1", "query2", "query3"])

     # Given a valid file name, key name, and yara rules, the method should successfully match the file against the rules and print only the rule name and matched strings if hide_filenames is True.
     def test_yara(self) -> None:
@@ -202,43 +202,44 @@ def test_search_cloudtrail(self) -> None:
         self.assertIn("SignatureVersion", output)
         self.assertTrue(json.loads(output))

-        def test_search_azure(self) -> None:
-            # Arrange
-            log_format = "json"
-            log_properties = ["data"]
-
-            # Test it doesnt crash on bad json
-            Search().search_file(
-                f"{BASE_PATH}/data/bad_azure.json",
-                "bad_azure.json",
-                ["azure.gz"],
-                False,
-                None,
-                log_format,
-                log_properties,
-            )
-            with patch("sys.stdout", new=StringIO()) as fake_out:
-                Search().search_file(
-                    f"{BASE_PATH}/data/azure.json",
-                    "azure.json",
-                    ["azure.gz"],
-                    False,
-                    None,
-                    log_format,
-                    log_properties,
-                    True,
-                )
-            output = fake_out.getvalue().strip()
-
-            # Assert we can parse the output
-            self.assertIn("SignatureVersion", output)
-            self.assertTrue(json.loads(output))
+    def test_search_azure(self) -> None:
+        # Arrange
+        log_format = "json"
+        log_properties = ["data"]
+
+        # Test it doesnt crash on bad json
+        Search().search_file(
+            f"{BASE_PATH}/data/bad_azure.json",
+            "bad_azure.json",
+            ["azure.gz"],
+            False,
+            None,
+            log_format,
+            log_properties,
+        )
+        Search().search_file(
+            f"{BASE_PATH}/data/azure.json",
+            "azure.json",
+            ["azure.gz"],
+            False,
+            None,
+            log_format,
+            log_properties,
+        )
+        with patch("sys.stdout", new=StringIO()) as fake_out:
+            Search().search_file(
+                f"{BASE_PATH}/data/azure_singleline.json",
+                "azure_singleline.json",
+                ["azure.gz"],
+                False,
+                None,
+                log_format,
+                log_properties,
+                True,
+            )
+            output = fake_out.getvalue().strip()
+
+            # Assert we can parse the output
+            self.assertIn("SignatureVersion", output)
+            self.assertTrue(json.loads(output))