diff --git a/.github/workflows/tests-qs.yml b/.github/workflows/tests-qs.yml new file mode 100644 index 000000000..bad1b97e2 --- /dev/null +++ b/.github/workflows/tests-qs.yml @@ -0,0 +1,27 @@ +name: quantumstrand tests + +on: + push: + branches: [ quantumstrand ] + pull_request: + branches: [ quantumstrand ] + +jobs: + qs_tests: + name: quantumstrand tests + runs-on: ubuntu-22.04 + steps: + - name: Checkout FLOSS with submodule + uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 + with: + submodules: true + - name: Set up Python 3.11 + uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 + with: + python-version: '3.11' + - name: Install FLOSS + run: | + pip install -r requirements.txt + pip install -e .[dev,qs] + - name: Run tests + run: pytest -k qs diff --git a/floss/qs/main.py b/floss/qs/main.py index 4b3e84930..6d77d916d 100644 --- a/floss/qs/main.py +++ b/floss/qs/main.py @@ -482,18 +482,90 @@ def check_is_xor(xor_key: int | None): return () -def check_is_reloc(reloc_offsets: Set[int], string: ExtractedString): - for addr in string.slice.range: - if addr in reloc_offsets: - return ("#reloc",) +class OffsetRanges: + def __init__(self, offsets: Optional[Set[int]] = None, *, _merged_ranges: Optional[List[Tuple[int, int]]] = None): + if _merged_ranges is not None: + self._ranges = _merged_ranges + return + + if not offsets: + self._ranges: List[Tuple[int, int]] = [] + return + + sorted_offsets = sorted(list(offsets)) + + ranges: List[Tuple[int, int]] = [] + start = sorted_offsets[0] + end = start + for offset in sorted_offsets[1:]: + if offset == end + 1: + end = offset + else: + ranges.append((start, end)) + start = offset + end = offset + ranges.append((start, end)) + self._ranges = ranges + + def __contains__(self, offset: int) -> bool: + if not self._ranges: + return False + + # Find the index where the offset would be inserted to maintain order. + index = bisect.bisect_left(self._ranges, (offset, 0)) + + # Check the range at the insertion index. + # This handles cases where the offset is the start of a range. + if index < len(self._ranges): + start, end = self._ranges[index] + if start == offset: + return True + + # Check the range just before the insertion index. + # This handles cases where the offset is within or at the end of a range. + if index > 0: + start, end = self._ranges[index - 1] + if start <= offset <= end: + return True + + return False + + def overlaps(self, start: int, end: int) -> bool: + if not self._ranges: + return False + + # Find the index where the start of the given range would be inserted + index = bisect.bisect_right(self._ranges, (start, 0)) + + # Check the range at index-1 for overlap + if index > 0: + prev_start, prev_end = self._ranges[index - 1] + if max(start, prev_start) <= min(end, prev_end): + return True + + # Check the range at index for overlap + if index < len(self._ranges): + next_start, next_end = self._ranges[index] + if max(start, next_start) <= min(end, next_end): + return True + + return False + + @classmethod + def from_merged_ranges(cls, merged_ranges: List[Tuple[int, int]]): + return cls(_merged_ranges=merged_ranges) + + +def check_is_reloc(reloc_offsets: OffsetRanges, string: ExtractedString): + if reloc_offsets.overlaps(string.slice.range.offset, string.slice.range.end - 1): + return ("#reloc",) return () -def check_is_code(code_offsets: Set[int], string: ExtractedString): - for addr in string.slice.range: - if addr in code_offsets: - return ("#code",) +def check_is_code(code_offsets: OffsetRanges, string: ExtractedString): + if code_offsets.overlaps(string.slice.range.offset, string.slice.range.end - 1): + return ("#code",) return () @@ -803,14 +875,16 @@ class SegmentLayout(Layout): class PELayout(Layout): + model_config = ConfigDict(arbitrary_types_allowed=True) + # xor key if the file was xor decoded xor_key: Optional[int] # file offsets of bytes that are part of the relocation table - reloc_offsets: Set[int] + reloc_offsets: OffsetRanges # file offsets of bytes that are recognized as code - code_offsets: Set[int] + code_offsets: OffsetRanges structures_by_address: Dict[int, Structure] @@ -848,6 +922,68 @@ class ResourceLayout(Layout): pass +def _merge_overlapping_ranges(ranges: List[Tuple[int, int]]) -> List[Tuple[int, int]]: + """ + Merge a list of (start, end) tuples into a list of contiguous ranges. + """ + if not ranges: + return [] + + sorted_ranges = sorted(ranges) + merged_ranges: List[Tuple[int, int]] = [] + for higher in sorted_ranges: + if not merged_ranges: + merged_ranges.append(higher) + else: + lower = merged_ranges[-1] + lower_start, lower_end = lower + higher_start, higher_end = higher + + # test for intersection between lower and higher: + # we know via sorting that lower_start <= higher_start + if higher_start <= lower_end + 1: + upper_bound = max(lower_end, higher_end) + merged_ranges[-1] = (lower_start, upper_bound) + else: + merged_ranges.append(higher) + return merged_ranges + + +def _get_code_ranges(ws: lancelot.Workspace, pe: pefile.PE, slice_: Slice) -> List[Tuple[int, int]]: + """ + Extract and return the raw, unmerged code ranges from a PE file. + """ + base_address = ws.base_address + + # cache because getting the offset is slow + @functools.lru_cache(maxsize=None) + def get_offset_from_rva_cached(rva): + try: + return pe.get_offset_from_rva(rva) + except pefile.PEFormatError as e: + logger.warning("%s", str(e)) + return None + + code_ranges: List[Tuple[int, int]] = [] + for function in ws.get_functions(): + cfg = ws.build_cfg(function) + for bb in cfg.basic_blocks.values(): + va = bb.address + rva = va - base_address + offset = get_offset_from_rva_cached(rva) + if offset is None: + continue + + size = bb.length + + if not slice_.contains_range(offset, size): + logger.warning("lancelot identified code at an invalid location, skipping basic block at 0x%x", rva) + continue + + code_ranges.append((offset, offset + size - 1)) + return code_ranges + + def compute_pe_layout(slice: Slice, xor_key: int | None) -> Layout: data = slice.data @@ -857,7 +993,7 @@ def compute_pe_layout(slice: Slice, xor_key: int | None) -> Layout: raise ValueError("pefile failed to load workspace") from e structures = collect_pe_structures(slice, pe) - reloc_offsets = get_reloc_offsets(slice, pe) + reloc_offsets = OffsetRanges(get_reloc_offsets(slice, pe)) structures_by_address = {} for structure in structures: @@ -872,30 +1008,10 @@ def compute_pe_layout(slice: Slice, xor_key: int | None) -> Layout: raise ValueError("lancelot failed to load workspace") from e # contains the file offsets of bytes that are part of recognized instructions. - code_offsets = set() with timing("lancelot: find code"): - base_address = ws.base_address - for function in ws.get_functions(): - cfg = ws.build_cfg(function) - for bb in cfg.basic_blocks.values(): - va = bb.address - rva = va - base_address - try: - offset = pe.get_offset_from_rva(rva) - except pefile.PEFormatError as e: - logger.warning("%s", str(e)) - continue - - size = bb.length - - if not slice.contains_range(offset, size): - logger.warning( - "lancelot identified code at an invalid location, skipping basic block at 0x%x", rva - ) - continue - - for fo in slice.range.slice(offset, size): - code_offsets.add(fo) + code_ranges = _get_code_ranges(ws, pe, slice) + merged_code_ranges = _merge_overlapping_ranges(code_ranges) + code_offsets = OffsetRanges.from_merged_ranges(merged_code_ranges) layout = PELayout( slice=slice, diff --git a/tests/data b/tests/data index 53e910192..c0eb9b1cd 160000 --- a/tests/data +++ b/tests/data @@ -1 +1 @@ -Subproject commit 53e910192ea6f3f4c825370389393bdd9631580c +Subproject commit c0eb9b1cd428ad10ce450a0a7d673d85dc457e95 diff --git a/tests/test_qs.py b/tests/test_qs.py index b092c8e74..1f28f8b17 100644 --- a/tests/test_qs.py +++ b/tests/test_qs.py @@ -22,7 +22,7 @@ @pytest.fixture def pma_binary_path(): - return CD / "data" / "pma" / "pma0303.exe_" + return CD / "data" / "pma" / "Practical Malware Analysis Lab 03-03.exe_" @pytest.fixture diff --git a/tests/test_qs_code_ranges.py b/tests/test_qs_code_ranges.py new file mode 100644 index 000000000..4c7fdf535 --- /dev/null +++ b/tests/test_qs_code_ranges.py @@ -0,0 +1,137 @@ +import pytest +from unittest.mock import Mock, MagicMock + +import pefile +import lancelot + +from floss.qs.main import ( + Slice, + Range, + _get_code_ranges, + _merge_overlapping_ranges, +) + + +# Tests for _merge_overlapping_ranges +def test_merge_empty_list(): + """Test merging an empty list of ranges.""" + assert _merge_overlapping_ranges([]) == [] + + +def test_merge_no_overlap(): + """Test merging ranges that do not overlap.""" + ranges = [(10, 20), (30, 40), (50, 60)] + assert _merge_overlapping_ranges(ranges) == [(10, 20), (30, 40), (50, 60)] + + +def test_merge_with_overlap(): + """Test merging ranges that partially overlap.""" + ranges = [(10, 20), (15, 25), (30, 40)] + assert _merge_overlapping_ranges(ranges) == [(10, 25), (30, 40)] + + +def test_merge_adjacent(): + """Test merging ranges that are right next to each other.""" + ranges = [(10, 20), (21, 30), (31, 40)] + assert _merge_overlapping_ranges(ranges) == [(10, 40)] + + +def test_merge_fully_contained(): + """Test merging ranges where some are fully contained within others.""" + ranges = [(10, 40), (15, 25), (20, 30)] + assert _merge_overlapping_ranges(ranges) == [(10, 40)] + + +def test_merge_complex_mix(): + """Test a complex mixture of overlapping, adjacent, and contained ranges.""" + ranges = [(50, 60), (10, 20), (18, 30), (35, 40), (39, 55)] + # After sorting: [(10, 20), (18, 30), (35, 40), (39, 55), (50, 60)] + # (10, 20) and (18, 30) -> (10, 30) + # (35, 40) and (39, 55) -> (35, 55) + # (35, 55) and (50, 60) -> (35, 60) + assert _merge_overlapping_ranges(ranges) == [(10, 30), (35, 60)] + + +# Tests for _get_code_ranges +@pytest.fixture +def mock_pe(): + """Fixture for a mocked pefile.PE object.""" + pe = MagicMock(spec=pefile.PE) + + def get_offset_from_rva(rva): + # Simple mapping for testing: offset is just rva + 0x1000 + return rva + 0x1000 + + pe.get_offset_from_rva.side_effect = get_offset_from_rva + return pe + + +@pytest.fixture +def mock_ws(): + """Fixture for a mocked lancelot.Workspace object.""" + ws = MagicMock(spec=lancelot.Workspace) + ws.base_address = 0x400000 + + # Mock functions and basic blocks + func1 = Mock() + func2 = Mock() + ws.get_functions.return_value = [func1, func2] + + bb1 = Mock(address=0x401000, length=0x10) # rva: 0x1000, offset: 0x2000 + bb2 = Mock(address=0x401020, length=0x15) # rva: 0x1020, offset: 0x2020 + bb3 = Mock(address=0x402000, length=0x20) # rva: 0x2000, offset: 0x3000 + + # Setup cfg for each function + cfg1 = Mock(basic_blocks={bb1.address: bb1, bb2.address: bb2}) + cfg2 = Mock(basic_blocks={bb3.address: bb3}) + + def build_cfg(func): + if func == func1: + return cfg1 + return cfg2 + + ws.build_cfg.side_effect = build_cfg + return ws + + +def test_get_code_ranges_basic(mock_ws, mock_pe): + """Test basic extraction of code ranges.""" + # Slice covers the entire mock file + slice_ = Slice(buf=b"", range=Range(offset=0, length=0x5000)) + ranges = _get_code_ranges(mock_ws, mock_pe, slice_) + + assert ranges == [ + (0x2000, 0x200F), # bb1: offset 0x2000, size 0x10 + (0x2020, 0x2034), # bb2: offset 0x2020, size 0x15 + (0x3000, 0x301F), # bb3: offset 0x3000, size 0x20 + ] + + +def test_get_code_ranges_skips_invalid_offset(mock_ws, mock_pe): + """Test that it skips basic blocks that fall outside the slice.""" + # Slice is small and only covers the first basic block + slice_ = Slice(buf=b"", range=Range(offset=0, length=0x2010)) + ranges = _get_code_ranges(mock_ws, mock_pe, slice_) + + # Only bb1 should be included + assert ranges == [(0x2000, 0x200F)] + + +def test_get_code_ranges_handles_pe_error(mock_ws, mock_pe): + """Test that it handles PEFormatError when getting an offset.""" + # Make one of the RVA lookups fail + def get_offset_from_rva_with_error(rva): + if rva == 0x1020: # Corresponds to bb2 + raise pefile.PEFormatError("Test Error") + return rva + 0x1000 + + mock_pe.get_offset_from_rva.side_effect = get_offset_from_rva_with_error + + slice_ = Slice(buf=b"", range=Range(offset=0, length=0x5000)) + ranges = _get_code_ranges(mock_ws, mock_pe, slice_) + + # bb2 should be skipped + assert ranges == [ + (0x2000, 0x200F), + (0x3000, 0x301F), + ] diff --git a/tests/test_qs_offset_ranges.py b/tests/test_qs_offset_ranges.py new file mode 100644 index 000000000..f3afd6a26 --- /dev/null +++ b/tests/test_qs_offset_ranges.py @@ -0,0 +1,113 @@ +import pytest + +from floss.qs.main import OffsetRanges + + +def test_offset_ranges_init_empty(): + """Test initialization with no offsets.""" + offsets = set() + ranges = OffsetRanges(offsets) + assert ranges._ranges == [] + + +def test_offset_ranges_init(): + """Test initialization with a mix of contiguous and non-contiguous offsets.""" + offsets = {0, 1, 2, 5, 6, 8, 10} + ranges = OffsetRanges(offsets) + assert ranges._ranges == [(0, 2), (5, 6), (8, 8), (10, 10)] + + +def test_offset_ranges_init_single_range(): + """Test initialization with a single contiguous block of offsets.""" + offsets = {10, 11, 12, 13, 14} + ranges = OffsetRanges(offsets) + assert ranges._ranges == [(10, 14)] + + +def test_offset_ranges_from_merged_ranges(): + """Test the from_merged_ranges class method.""" + merged = [(10, 20), (30, 40)] + ranges = OffsetRanges.from_merged_ranges(merged) + assert ranges._ranges == [(10, 20), (30, 40)] + + +@pytest.fixture +def sample_ranges(): + """Provides a standard OffsetRanges instance for testing.""" + # Ranges will be: (10, 15), (20, 25), (30, 30) + offsets = {10, 11, 12, 13, 14, 15, 20, 21, 22, 23, 24, 25, 30} + return OffsetRanges(offsets) + + +def test_contains_empty(sample_ranges): + """Test __contains__ on an empty OffsetRanges instance.""" + empty_ranges = OffsetRanges(set()) + assert 10 not in empty_ranges + + +def test_contains_inside(sample_ranges): + """Test __contains__ for an offset well within a range.""" + assert 12 in sample_ranges + assert 23 in sample_ranges + + +def test_contains_edges(sample_ranges): + """Test __contains__ for offsets at the exact start and end of ranges.""" + assert 10 in sample_ranges # Start of first range + assert 15 in sample_ranges # End of first range + assert 20 in sample_ranges # Start of second range + assert 25 in sample_ranges # End of second range + assert 30 in sample_ranges # Single-point range + + +def test_contains_outside(sample_ranges): + """Test __contains__ for offsets outside of any range.""" + assert 9 not in sample_ranges # Before first range + assert 16 not in sample_ranges # Between ranges + assert 29 not in sample_ranges # Between ranges + assert 31 not in sample_ranges # After last range + + +def test_overlaps_empty(sample_ranges): + """Test overlaps on an empty OffsetRanges instance.""" + empty_ranges = OffsetRanges(set()) + assert not empty_ranges.overlaps(10, 20) + + +def test_overlaps_fully_contained(sample_ranges): + """Test overlaps where the query range is fully inside an existing range.""" + assert sample_ranges.overlaps(11, 14) # Fully inside (10, 15) + assert sample_ranges.overlaps(21, 22) # Fully inside (20, 25) + + +def test_overlaps_contains_full_range(sample_ranges): + """Test overlaps where the query range fully contains an existing range.""" + assert sample_ranges.overlaps(9, 16) # Contains (10, 15) + assert sample_ranges.overlaps(19, 26) # Contains (20, 25) + assert sample_ranges.overlaps(29, 31) # Contains (30, 30) + + +def test_overlaps_start(sample_ranges): + """Test overlaps where the query range overlaps the beginning of an existing range.""" + assert sample_ranges.overlaps(8, 12) # Overlaps start of (10, 15) + assert sample_ranges.overlaps(18, 20) # Touches start of (20, 25) + + +def test_overlaps_end(sample_ranges): + """Test overlaps where the query range overlaps the end of an existing range.""" + assert sample_ranges.overlaps(14, 17) # Overlaps end of (10, 15) + assert sample_ranges.overlaps(25, 28) # Touches end of (20, 25) + + +def test_overlaps_multiple_ranges(sample_ranges): + """Test overlaps where the query range spans across multiple existing ranges.""" + assert sample_ranges.overlaps(12, 22) # Spans from first to second range + assert sample_ranges.overlaps(14, 30) # Spans all three ranges + + +def test_no_overlap(sample_ranges): + """Test overlaps for ranges that do not overlap at all.""" + assert not sample_ranges.overlaps(0, 8) # Before all ranges + assert not sample_ranges.overlaps(16, 19) # Between ranges + assert not sample_ranges.overlaps(26, 29) # Between ranges + assert not sample_ranges.overlaps(31, 40) # After all ranges diff --git a/tests/test_qs_pma0101.py b/tests/test_qs_pma0101.py new file mode 100644 index 000000000..868a8e940 --- /dev/null +++ b/tests/test_qs_pma0101.py @@ -0,0 +1,92 @@ +from pathlib import Path + +import pytest + +from floss.qs.main import ( + Slice, + compute_layout, + load_databases, + collect_strings, + extract_layout_strings, +) + + +@pytest.fixture(scope="module") +def pma0101_layout(): + """ + Provides the analyzed layout. + The analysis pipeline (string extraction, tagging, structure marking) + is run once for all tests in this module. + """ + binary_path = Path("tests") / Path("data") / Path("pma") / Path("Practical Malware Analysis Lab 01-01.dll_") + slice_buf = binary_path.read_bytes() + file_slice = Slice.from_bytes(slice_buf) + layout = compute_layout(file_slice) + extract_layout_strings(layout, 6) + taggers = load_databases() + layout.tag_strings(taggers) + layout.mark_structures() + return layout + + +def find_string(layout, text): + """Helper to find a specific string in the layout.""" + all_strings = collect_strings(layout) + found = [s for s in all_strings if s.string.string == text] + return found[0] if found else None + + +def test_pe_layout(pma0101_layout): + assert pma0101_layout.name == "pe" + + +def test_header_strings(pma0101_layout): + dos_mode_str = find_string(pma0101_layout, "!This program cannot be run in DOS mode.") + assert dos_mode_str is not None + assert "#common" in dos_mode_str.tags + + rdata_str = find_string(pma0101_layout, "@.data") + assert rdata_str is not None + assert "#common" in rdata_str.tags + assert rdata_str.structure == "section header" + + reloc_str = find_string(pma0101_layout, ".reloc") + assert reloc_str is not None + assert "#common" in reloc_str.tags + assert reloc_str.structure == "section header" + + +def test_rdata_strings(pma0101_layout): + kernel32_str = find_string(pma0101_layout, "KERNEL32.dll") + assert kernel32_str is not None + assert "#winapi" in kernel32_str.tags + assert kernel32_str.structure == "import table" + + msvcrt_str = find_string(pma0101_layout, "MSVCRT.dll") + assert msvcrt_str is not None + assert "#winapi" in msvcrt_str.tags + assert msvcrt_str.structure == "import table" + + initterm_str = find_string(pma0101_layout, "_initterm") + assert initterm_str is not None + assert "#winapi" in initterm_str.tags + assert "#code-junk" in initterm_str.tags + assert initterm_str.structure == "import table" + + +def test_data_strings(pma0101_layout): + ip_str = find_string(pma0101_layout, "127.26.152.13") + assert ip_str is not None + + garbage_str = find_string(pma0101_layout, "SADFHUHF") + assert garbage_str is not None + + +def test_strings(pma0101_layout): + all_strings = collect_strings(pma0101_layout) + + assert len(all_strings) == 21 + + # assert count of expected strings not tagged as #code or #reloc + filtered_strings = [s for s in all_strings if not s.tags.intersection({"#code", "#reloc"})] + assert len(filtered_strings) == 17