diff --git a/vulnerabilities/importer.py b/vulnerabilities/importer.py index 6fdfb3ef2..9edf63d00 100644 --- a/vulnerabilities/importer.py +++ b/vulnerabilities/importer.py @@ -12,7 +12,6 @@ import logging import os import shutil -import tempfile import traceback import xml.etree.ElementTree as ET from pathlib import Path @@ -23,9 +22,7 @@ from typing import Set from typing import Tuple -from binaryornot.helpers import is_binary_string -from git import DiffIndex -from git import Repo +from fetchcode.vcs import fetch_via_vcs from license_expression import Licensing from packageurl import PackageURL from univers.version_range import VersionRange @@ -312,193 +309,37 @@ def advisory_data(self) -> Iterable[AdvisoryData]: raise NotImplementedError -# TODO: Needs rewrite -class GitImporter(Importer): - def validate_configuration(self) -> None: +class ForkError(Exception): + pass - if not self.config.create_working_directory and self.config.working_directory is None: - self.error( - '"create_working_directory" is not set but "working_directory" is set to ' - "the default, which calls tempfile.mkdtemp()" - ) - if not self.config.create_working_directory and not os.path.exists( - self.config.working_directory - ): - self.error( - '"working_directory" does not contain an existing directory and' - '"create_working_directory" is not set' - ) - - if not self.config.remove_working_directory and self.config.working_directory is None: - self.error( - '"remove_working_directory" is not set and "working_directory" is set to ' - "the default, which calls tempfile.mkdtemp()" - ) +class GitImporter(Importer): + def __init__(self, repo_url): + super().__init__() + self.repo_url = repo_url + self.vcs_response = None def __enter__(self): - self._ensure_working_directory() - self._ensure_repository() - - def __exit__(self, exc_type, exc_val, exc_tb): - if self.config.remove_working_directory: - shutil.rmtree(self.config.working_directory) - - def file_changes( - self, - subdir: str = None, 
-        recursive: bool = False,
-        file_ext: Optional[str] = None,
-    ) -> Tuple[Set[str], Set[str]]:
-        """
-        Returns all added and modified files since last_run_date or cutoff_date (whichever is more
-        recent).
-
-        :param subdir: filter by files in this directory
-        :param recursive: whether to include files in subdirectories
-        :param file_ext: filter files by this extension
-        :return: The first set contains (absolute paths to) added files, the second one modified
-            files
-        """
-        if subdir is None:
-            working_dir = self.config.working_directory
-        else:
-            working_dir = os.path.join(self.config.working_directory, subdir)
+        super().__enter__()
+        self.clone()
+        return self
 
-        path = Path(working_dir)
+    def __exit__(self, exc_type=None, exc_val=None, exc_tb=None):
+        self.vcs_response.delete()
 
-        if self.config.last_run_date is None and self.config.cutoff_date is None:
-            if recursive:
-                glob = "**/*"
-            else:
-                glob = "*"
-
-            if file_ext:
-                glob = f"{glob}.{file_ext}"
-
-            return {str(p) for p in path.glob(glob) if p.is_file()}, set()
-
-        return self._collect_file_changes(subdir=subdir, recursive=recursive, file_ext=file_ext)
-
-    def _collect_file_changes(
-        self,
-        subdir: Optional[str],
-        recursive: bool,
-        file_ext: Optional[str],
-    ) -> Tuple[Set[str], Set[str]]:
-
-        added_files, updated_files = set(), set()
-
-        # find the most ancient commit we need to diff with
-        cutoff_commit = None
-        for commit in self._repo.iter_commits(self._repo.head):
-            if commit.committed_date < self.cutoff_timestamp:
-                break
-            cutoff_commit = commit
-
-        if cutoff_commit is None:
-            return added_files, updated_files
-
-        def _is_binary(d: DiffIndex):
-            return is_binary_string(d.b_blob.data_stream.read(1024))
-
-        for d in cutoff_commit.diff(self._repo.head.commit):
-            if not _include_file(d.b_path, subdir, recursive, file_ext) or _is_binary(d):
-                continue
-
-            abspath = os.path.join(self.config.working_directory, d.b_path)
-            if d.new_file:
-                added_files.add(abspath)
-            elif d.a_blob and d.b_blob:
-                if d.a_path != d.b_path:
-                    # consider moved files as
added - added_files.add(abspath) - elif d.a_blob != d.b_blob: - updated_files.add(abspath) - - # Any file that has been added and then updated inside the window of the git history we - # looked at, should be considered "added", not "updated", since it does not exist in the - # database yet. - updated_files = updated_files - added_files - - return added_files, updated_files - - def _ensure_working_directory(self) -> None: - if self.config.working_directory is None: - self.config.working_directory = tempfile.mkdtemp() - elif self.config.create_working_directory and not os.path.exists( - self.config.working_directory - ): - os.mkdir(self.config.working_directory) - - def _ensure_repository(self) -> None: - if not os.path.exists(os.path.join(self.config.working_directory, ".git")): - self._clone_repository() - return - self._repo = Repo(self.config.working_directory) - - if self.config.branch is None: - self.config.branch = str(self._repo.active_branch) - branch = self.config.branch - self._repo.head.reference = self._repo.heads[branch] - self._repo.head.reset(index=True, working_tree=True) - - remote = self._find_or_add_remote() - self._update_from_remote(remote, branch) - - def _clone_repository(self) -> None: - kwargs = {} - if self.config.branch: - kwargs["branch"] = self.config.branch - - self._repo = Repo.clone_from( - self.config.repository_url, self.config.working_directory, **kwargs - ) - - def _find_or_add_remote(self): - remote = None - for r in self._repo.remotes: - if r.url == self.config.repository_url: - remote = r - break - - if remote is None: - remote = self._repo.create_remote( - "added_by_vulnerablecode", url=self.config.repository_url - ) - - return remote - - def _update_from_remote(self, remote, branch) -> None: - fetch_info = remote.fetch() - if len(fetch_info) == 0: - return - branch = self._repo.branches[branch] - branch.set_reference(remote.refs[branch.name]) - self._repo.head.reset(index=True, working_tree=True) - - -def _include_file( - 
path: str, - subdir: Optional[str] = None, - recursive: bool = False, - file_ext: Optional[str] = None, -) -> bool: - match = True - - if subdir: - if not subdir.endswith(os.path.sep): - subdir = f"{subdir}{os.path.sep}" - - match = match and path.startswith(subdir) - - if not recursive: - match = match and (os.path.sep not in path[len(subdir or "") :]) - - if file_ext: - match = match and path.endswith(f".{file_ext}") + def clone(self): + try: + self.vcs_response = fetch_via_vcs(self.repo_url) + except Exception as e: + msg = f"Failed to fetch {self.repo_url} via vcs: {e}" + logger.error(msg) + raise ForkError(msg) from e - return match + def advisory_data(self) -> Iterable[AdvisoryData]: + """ + Return AdvisoryData objects corresponding to the data being imported + """ + raise NotImplementedError # TODO: Needs rewrite diff --git a/vulnerabilities/importers/__init__.py b/vulnerabilities/importers/__init__.py index b3e6063f0..8bb71686f 100644 --- a/vulnerabilities/importers/__init__.py +++ b/vulnerabilities/importers/__init__.py @@ -16,6 +16,7 @@ from vulnerabilities.importers import openssl from vulnerabilities.importers import pysec from vulnerabilities.importers import redhat +from vulnerabilities.importers import rust IMPORTERS_REGISTRY = [ nginx.NginxImporter, @@ -26,7 +27,8 @@ redhat.RedhatImporter, pysec.PyPIImporter, debian.DebianImporter, - gitlab.GitLabAPIImporter, + gitlab.GitLabGitImporter, + rust.RustImporter, ] IMPORTERS_REGISTRY = {x.qualified_name: x for x in IMPORTERS_REGISTRY} diff --git a/vulnerabilities/importers/gitlab.py b/vulnerabilities/importers/gitlab.py index 214c680cc..0016fb974 100644 --- a/vulnerabilities/importers/gitlab.py +++ b/vulnerabilities/importers/gitlab.py @@ -8,16 +8,15 @@ # import logging -import os import traceback from datetime import datetime +from pathlib import Path from typing import Iterable from typing import List from typing import Mapping from typing import Optional import pytz -import saneyaml from dateutil 
import parser as dateparser from django.db.models.query import QuerySet from fetchcode.vcs import fetch_via_vcs @@ -29,7 +28,7 @@ from vulnerabilities.importer import AdvisoryData from vulnerabilities.importer import AffectedPackage -from vulnerabilities.importer import Importer +from vulnerabilities.importer import GitImporter from vulnerabilities.importer import Reference from vulnerabilities.importer import UnMergeablePackageError from vulnerabilities.improver import Improver @@ -42,6 +41,7 @@ from vulnerabilities.utils import AffectedPackage as LegacyAffectedPackage from vulnerabilities.utils import build_description from vulnerabilities.utils import get_affected_packages_by_patched_package +from vulnerabilities.utils import load_yaml from vulnerabilities.utils import nearest_patched_package from vulnerabilities.utils import resolve_version_range @@ -71,31 +71,45 @@ def fork_and_get_dir(url): return fetch_via_vcs(url).dest_dir -class ForkError(Exception): - pass - - -class GitLabAPIImporter(Importer): +class GitLabGitImporter(GitImporter): spdx_license_expression = "MIT" license_url = "https://gitlab.com/gitlab-org/advisories-community/-/blob/main/LICENSE" - gitlab_url = "git+https://gitlab.com/gitlab-org/advisories-community/" + + def __init__(self): + super().__init__(repo_url="git+https://gitlab.com/gitlab-org/advisories-community/") def advisory_data(self) -> Iterable[AdvisoryData]: try: - fork_directory = fork_and_get_dir(url=self.gitlab_url) - except Exception as e: - logger.error(f"Can't clone url {self.gitlab_url}") - raise ForkError(self.gitlab_url) from e - for root_dir in os.listdir(fork_directory): - # skip well known files and directories that contain no advisory data - if root_dir in ("ci", "CODEOWNERS", "README.md", "LICENSE", ".git"): - continue - if root_dir not in PURL_TYPE_BY_GITLAB_SCHEME: - logger.error(f"Unknown package type: {root_dir}") - continue - for root, _, files in os.walk(os.path.join(fork_directory, root_dir)): - for file in 
files:
-                    yield parse_gitlab_advisory(file=os.path.join(root, file))
+            self.clone()
+            path = Path(self.vcs_response.dest_dir)
+
+            glob = "**/*.yml"
+            files = (p for p in path.glob(glob) if p.is_file())
+            for file in files:
+                purl_type = get_gitlab_package_type(path=file, root=path)
+                if not purl_type:
+                    logger.error(f"Unknown gitlab directory structure {file!r}")
+                    continue
+
+                if purl_type in PURL_TYPE_BY_GITLAB_SCHEME:
+                    yield parse_gitlab_advisory(file)
+
+                else:
+                    logger.error(f"Unknown package type {purl_type!r}")
+                    continue
+        finally:
+            if self.vcs_response:
+                self.vcs_response.delete()
+
+
+def get_gitlab_package_type(path: Path, root: Path):
+    """
+    Return a package type extracted from a gitlab advisory path
+    """
+    relative = path.relative_to(root)
+    parts = relative.parts
+    gitlab_schema = parts[0]
+    return gitlab_schema
 
 
 def get_purl(package_slug):
@@ -168,10 +182,12 @@ def parse_gitlab_advisory(file):
     identifiers:
     - "GMS-2018-26"
     """
-    with open(file, "r") as f:
-        gitlab_advisory = saneyaml.load(f)
+    gitlab_advisory = load_yaml(file)
+
     if not isinstance(gitlab_advisory, dict):
-        logger.error(f"parse_yaml_file: yaml_file is not of type `dict`: {gitlab_advisory!r}")
+        logger.error(
+            f"parse_gitlab_advisory: unknown gitlab advisory format in {file!r} with data: {gitlab_advisory!r}"
+        )
         return
 
     # refer to schema here https://gitlab.com/gitlab-org/advisories-community/-/blob/main/ci/schema/schema.json
@@ -261,7 +277,7 @@ def __init__(self) -> None:
 
     @property
     def interesting_advisories(self) -> QuerySet:
-        return Advisory.objects.filter(created_by=GitLabAPIImporter.qualified_name)
+        return Advisory.objects.filter(created_by=GitLabGitImporter.qualified_name)
 
     def get_package_versions(
         self, package_url: PackageURL, until: Optional[datetime] = None
diff --git a/vulnerabilities/importers/rust.py b/vulnerabilities/importers/rust.py
index 701405128..893826411 100644
--- a/vulnerabilities/importers/rust.py
+++ b/vulnerabilities/importers/rust.py
@@ -7,8 +7,10 @@
 # See
https://aboutcode.org for more information about nexB OSS projects. # -import asyncio +import logging from itertools import chain +from pathlib import Path +from typing import Iterable from typing import List from typing import Optional from typing import Set @@ -18,6 +20,7 @@ import toml from dateutil.parser import parse from packageurl import PackageURL +from univers.version_range import CargoVersionRange from univers.version_range import VersionRange from univers.versions import SemverVersion @@ -25,57 +28,37 @@ from vulnerabilities.importer import GitImporter from vulnerabilities.importer import Reference from vulnerabilities.package_managers import CratesVersionAPI +from vulnerabilities.package_managers import PackageVersion from vulnerabilities.utils import nearest_patched_package +logger = logging.getLogger(__name__) -class RustImporter(GitImporter): - def __enter__(self): - super(RustImporter, self).__enter__() - - if not getattr(self, "_added_files", None): - self._added_files, self._updated_files = self.file_changes( - subdir="crates", # TODO Consider importing the advisories for cargo, etc as well. 
- recursive=True, - file_ext="md", - ) - @property - def crates_api(self): - if not hasattr(self, "_crates_api"): - setattr(self, "_crates_api", CratesVersionAPI()) - return self._crates_api - - def set_api(self, packages): - asyncio.run(self.crates_api.load_api(packages)) - - def updated_advisories(self) -> Set[AdvisoryData]: - return self._load_advisories(self._updated_files.union(self._added_files)) - - def _load_advisories(self, files) -> Set[AdvisoryData]: - # per @tarcieri It will always be named RUSTSEC-0000-0000.md - # https://github.com/nexB/vulnerablecode/pull/281/files#r528899864 - files = [f for f in files if not f.endswith("-0000.md")] # skip temporary files - packages = self.collect_packages(files) - self.set_api(packages) - - while files: - batch, files = files[: self.batch_size], files[self.batch_size :] - advisories = [] - for path in batch: - advisory = self._load_advisory(path) - if advisory: - advisories.append(advisory) - yield advisories - - def collect_packages(self, paths): - packages = set() - for path in paths: - record = get_advisory_data(path) - packages.add(record["advisory"]["package"]) - - return packages - - def _load_advisory(self, path: str) -> Optional[AdvisoryData]: +class RustImporter(GitImporter): + spdx_license_expression = "CC0-1.0" + license_url = "https://github.com/rustsec/advisory-db/blob/main/LICENSE.txt" + + def __init__(self): + super().__init__(repo_url="git+https://github.com/rustsec/advisory-db") + self.pkg_manager_api = CratesVersionAPI() + + def advisory_data(self) -> Iterable[AdvisoryData]: + try: + self.clone() + path = Path(self.vcs_response.dest_dir) + glob = "crates/**/*.md" + files = (p for p in path.glob(glob) if p.is_file()) + for file in files: + # per @tarcieri It will always be named RUSTSEC-0000-0000.md + # https://github.com/nexB/vulnerablecode/pull/281/files#r528899864 + if not file.stem.endswith("-0000"): # skip temporary files + # packages = collect_packages(files) + yield 
self.parse_rust_advisory(str(file)) + finally: + if self.vcs_response: + self.vcs_response.delete() + + def parse_rust_advisory(self, path: str) -> Optional[AdvisoryData]: record = get_advisory_data(path) advisory = record.get("advisory", {}) crate_name = advisory["package"] @@ -84,40 +67,36 @@ def _load_advisory(self, path: str) -> Optional[AdvisoryData]: references.append(Reference(url=advisory["url"])) publish_date = parse(advisory["date"]).replace(tzinfo=pytz.UTC) - all_versions = self.crates_api.get(crate_name, publish_date).valid_versions + all_versions = self.pkg_manager_api.fetch(crate_name) # FIXME: Avoid wildcard version ranges for now. # See https://github.com/RustSec/advisory-db/discussions/831 - affected_ranges = [ - VersionRange.from_scheme_version_spec_string("semver", r) - for r in chain.from_iterable(record.get("affected", {}).get("functions", {}).values()) - if r != "*" - ] - - unaffected_ranges = [ - VersionRange.from_scheme_version_spec_string("semver", r) - for r in record.get("versions", {}).get("unaffected", []) - if r != "*" - ] - resolved_ranges = [ - VersionRange.from_scheme_version_spec_string("semver", r) - for r in record.get("versions", {}).get("patched", []) - if r != "*" - ] - - unaffected, affected = categorize_versions( - all_versions, unaffected_ranges, affected_ranges, resolved_ranges - ) - - impacted_purls = [PackageURL(type="cargo", name=crate_name, version=v) for v in affected] - resolved_purls = [PackageURL(type="cargo", name=crate_name, version=v) for v in unaffected] - - cve_id = None - if "aliases" in advisory: - for alias in advisory["aliases"]: - if alias.startswith("CVE-"): - cve_id = alias - break + # affected_ranges = [ + # CargoVersionRange.from_natives(r) + # for r in chain.from_iterable(record.get("affected", {}).get("functions", {}).values()) + # if r != "*" + # ] + # + # unaffected_ranges = [ + # CargoVersionRange.from_natives(r) + # for r in record.get("versions", {}).get("unaffected", []) + # if r != "*" + # ] 
+ # resolved_ranges = [ + # CargoVersionRange.from_natives(r) + # for r in record.get("versions", {}).get("patched", []) + # if r != "*" + # ] + # + # unaffected, affected = categorize_versions( + # all_versions, unaffected_ranges, affected_ranges, resolved_ranges + # ) + + # impacted_purls = [PackageURL(type="cargo", name=crate_name, version=v) for v in affected] + # resolved_purls = [PackageURL(type="cargo", name=crate_name, version=v) for v in unaffected] + + aliases = advisory.get("aliases") or [] + aliases.append(advisory.get("id")) references.append( Reference( @@ -126,16 +105,17 @@ def _load_advisory(self, path: str) -> Optional[AdvisoryData]: ) ) - return AdvisoryData( - summary=advisory.get("description", ""), - affected_packages=nearest_patched_package(impacted_purls, resolved_purls), - vulnerability_id=cve_id, + x = AdvisoryData( + aliases=aliases, + summary=advisory.get("description") or "", + # affected_packages=nearest_patched_package(impacted_purls, resolved_purls), references=references, ) + return x def categorize_versions( - all_versions: Set[str], + all_versions: Iterable[PackageVersion], unaffected_version_ranges: List[VersionRange], affected_version_ranges: List[VersionRange], resolved_version_ranges: List[VersionRange], @@ -157,19 +137,19 @@ def categorize_versions( # TODO: This is probably wrong for version in all_versions: - version_obj = SemverVersion(version) + version_obj = SemverVersion(version.value) if affected_version_ranges and all([version_obj in av for av in affected_version_ranges]): - affected.add(version) + affected.add(version.value) elif unaffected_version_ranges and all( [version_obj in av for av in unaffected_version_ranges] ): - unaffected.add(version) + unaffected.add(version.value) elif resolved_version_ranges and all([version_obj in av for av in resolved_version_ranges]): - unaffected.add(version) + unaffected.add(version.value) # If some versions were not classified above, one or more of the given ranges might be empty, 
so
     # the remaining versions default to either affected or unaffected.
-    uncategorized_versions = all_versions - unaffected.union(affected)
+    uncategorized_versions = {i.value for i in all_versions} - unaffected.union(affected)
     if uncategorized_versions:
         if not affected_version_ranges:
             affected.update(uncategorized_versions)
@@ -239,3 +219,11 @@ def get_advisory_data(location):
     with open(location) as lines:
         toml_lines = get_toml_lines(lines)
         return data_from_toml_lines(toml_lines)
+
+
+def collect_packages(paths):
+    packages = set()
+    for path in paths:
+        record = get_advisory_data(path)
+        packages.add(record["advisory"]["package"])
+    return packages
diff --git a/vulnerabilities/tests/test_data_source.py b/vulnerabilities/tests/test_data_source.py
index fe17cafa7..8ba8503bf 100644
--- a/vulnerabilities/tests/test_data_source.py
+++ b/vulnerabilities/tests/test_data_source.py
@@ -7,24 +7,19 @@
 # See https://aboutcode.org for more information about nexB OSS projects.
 #
 
-import datetime
 import os
-import shutil
-import tempfile
 import xml.etree.ElementTree as ET
-import zipfile
+from typing import Iterable
 from unittest import TestCase
-from unittest.mock import MagicMock
-from unittest.mock import patch
 
-import git
 import pytest
 from packageurl import PackageURL
 
+from vulnerabilities.importer import AdvisoryData
+from vulnerabilities.importer import ForkError
 from vulnerabilities.importer import GitImporter
-from vulnerabilities.importer import InvalidConfigurationError
+from vulnerabilities.importer import Importer
 from vulnerabilities.importer import OvalImporter
-from vulnerabilities.importer import _include_file
 from vulnerabilities.oval_parser import OvalParser
 
 BASE_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -41,249 +36,6 @@ def load_oval_data():
     return etrees_of_oval
 
 
-@pytest.fixture
-def clone_url(tmp_path):
-    git_dir = tmp_path / "git_dir"
-    repo = git.Repo.init(str(git_dir))
-    new_file_path = str(git_dir / "file")
-    open(new_file_path, "wb").close()
- repo.index.add([new_file_path]) - repo.index.commit("Added a new file") - try: - yield str(git_dir) - finally: - shutil.rmtree(git_dir) - - -@pytest.fixture -def clone_url2(tmp_path): - git_dir = tmp_path / "git_dir2" - repo = git.Repo.init(str(git_dir)) - new_file_path = str(git_dir / "file2") - open(new_file_path, "wb").close() - repo.index.add([new_file_path]) - repo.index.commit("Added a new file") - - try: - yield str(git_dir) - finally: - shutil.rmtree(git_dir) - - -def mk_ds(**kwargs): - # just for convenience, since this is a mandatory parameter we always pass a value - if "repository_url" not in kwargs: - kwargs["repository_url"] = "asdf" - - last_run_date = kwargs.pop("last_run_date", None) - cutoff_date = kwargs.pop("cutoff_date", None) - - # batch_size is a required parameter of the base class, unrelated to these tests - return GitImporter( - batch_size=100, last_run_date=last_run_date, cutoff_date=cutoff_date, config=kwargs - ) - - -def test_GitImporter_repository_url_required(no_mkdir, no_rmtree): - - with pytest.raises(InvalidConfigurationError): - GitImporter(batch_size=100) - - -def test_GitImporter_validate_configuration_create_working_directory_must_be_set_when_working_directory_is_default( - no_mkdir, no_rmtree -): - - with pytest.raises(InvalidConfigurationError): - mk_ds(create_working_directory=False) - - -def test_GitImporter_validate_configuration_remove_working_directory_must_be_set_when_working_directory_is_default( - no_mkdir, no_rmtree -): - - with pytest.raises(InvalidConfigurationError): - mk_ds(remove_working_directory=False) - - -@patch("os.path.exists", return_value=True) -def test_GitImporter_validate_configuration_remove_working_directory_is_applied( - no_mkdir, no_rmtree -): - - ds = mk_ds(remove_working_directory=False, working_directory="/some/directory") - - assert not ds.config.remove_working_directory - - -def test_GitImporter_validate_configuration_working_directory_must_exist_when_create_working_directory_is_not_set( - 
no_mkdir, no_rmtree -): - - with pytest.raises(InvalidConfigurationError): - mk_ds(working_directory="/does/not/exist", create_working_directory=False) - - -def test_GitImporter_contextmgr_working_directory_is_created_and_removed(tmp_path, clone_url): - - wd = tmp_path / "working" - ds = mk_ds( - working_directory=str(wd), - create_working_directory=True, - remove_working_directory=True, - repository_url=clone_url, - ) - - with ds: - assert str(wd) == ds.config.working_directory - assert (wd / ".git").exists() - assert (wd / "file").exists() - - assert not (wd / ".git").exists() - - -@patch("tempfile.mkdtemp") -def test_GitImporter_contextmgr_calls_mkdtemp_if_working_directory_is_not_set( - mkdtemp, tmp_path, clone_url -): - - mkdtemp.return_value = str(tmp_path / "working") - ds = mk_ds(repository_url=clone_url) - - with ds: - assert mkdtemp.called - assert ds.config.working_directory == str(tmp_path / "working") - - -def test_GitImporter_contextmgr_uses_existing_repository( - clone_url, - clone_url2, - no_mkdir, - no_rmtree, -): - ds = mk_ds( - working_directory=clone_url, - repository_url=clone_url2, - create_working_directory=False, - remove_working_directory=False, - ) - - with ds: - # also make sure we switch the branch (original do not have file2) - assert os.path.exists(os.path.join(ds.config.working_directory, "file2")) - - assert os.path.exists(ds.config.working_directory) - - -def test__include_file(): - - assert _include_file("foo.json", subdir=None, recursive=False, file_ext=None) - assert not _include_file("foo/bar.json", subdir=None, recursive=False, file_ext=None) - assert _include_file("foo/bar.json", subdir="foo/", recursive=False, file_ext=None) - assert _include_file("foo/bar.json", subdir="foo", recursive=False, file_ext=None) - assert not _include_file("foobar.json", subdir="foo", recursive=False, file_ext=None) - assert _include_file("foo/bar.json", subdir=None, recursive=True, file_ext=None) - assert not _include_file("foo/bar.json", 
subdir=None, recursive=True, file_ext="yaml") - assert _include_file("foo/bar/baz.json", subdir="foo", recursive=True, file_ext="json") - assert not _include_file("bar/foo/baz.json", subdir="foo", recursive=True, file_ext="json") - - -class GitImporterTest(TestCase): - - tempdir = None - - @classmethod - def setUpClass(cls) -> None: - cls.tempdir = tempfile.mkdtemp() - zip_path = os.path.join(TEST_DATA, "advisory-db.zip") - - with zipfile.ZipFile(zip_path, "r") as zip_ref: - zip_ref.extractall(cls.tempdir) - - @classmethod - def tearDownClass(cls) -> None: - shutil.rmtree(cls.tempdir) - - def setUp(self) -> None: - self.repodir = os.path.join(self.tempdir, "advisory-db") - - def mk_ds(self, **kwargs) -> GitImporter: - kwargs["working_directory"] = self.repodir - kwargs["create_working_directory"] = False - kwargs["remove_working_directory"] = False - - ds = mk_ds(**kwargs) - ds._update_from_remote = MagicMock() - return ds - - def test_file_changes_last_run_date_and_cutoff_date_is_None(self): - - ds = self.mk_ds(last_run_date=None, cutoff_date=None) - - with ds: - added_files, updated_files = ds.file_changes( - subdir="rust", recursive=True, file_ext="toml" - ) - - assert len(updated_files) == 0 - - assert set(added_files) == { - os.path.join(self.repodir, f) - for f in { - "rust/cargo/CVE-2019-16760.toml", - "rust/rustdoc/CVE-2018-1000622.toml", - "rust/std/CVE-2018-1000657.toml", - "rust/std/CVE-2018-1000810.toml", - "rust/std/CVE-2019-12083.toml", - } - } - - def test_file_changes_cutoff_date_is_now(self): - - ds = self.mk_ds(last_run_date=None, cutoff_date=datetime.datetime.now()) - - with ds: - added_files, updated_files = ds.file_changes( - subdir="cargo", recursive=True, file_ext="toml" - ) - - assert len(added_files) == 0 - assert len(updated_files) == 0 - - def test_file_changes_include_new_advisories(self): - - last_run_date = datetime.datetime(year=2020, month=3, day=29) - cutoff_date = last_run_date - datetime.timedelta(weeks=52 * 3) - ds = 
self.mk_ds(last_run_date=last_run_date, cutoff_date=cutoff_date) - - with ds: - added_files, updated_files = ds.file_changes( - subdir="crates", recursive=True, file_ext="toml" - ) - - assert len(added_files) >= 2 - assert os.path.join(self.repodir, "crates/bitvec/RUSTSEC-2020-0007.toml") in added_files - assert os.path.join(self.repodir, "crates/hyper/RUSTSEC-2020-0008.toml") in added_files - assert len(updated_files) == 0 - - def test_file_changes_include_fixed_advisories(self): - # pick a date that includes commit 9889ed0831b4fb4beb7675de361926d2e9a99c20 - # ("Fix patched version for RUSTSEC-2020-0008") - last_run_date = datetime.datetime( - year=2020, month=3, day=31, hour=17, minute=40, tzinfo=datetime.timezone.utc - ) - ds = self.mk_ds(last_run_date=last_run_date, cutoff_date=None) - - with ds: - added_files, updated_files = ds.file_changes( - subdir="crates", recursive=True, file_ext="toml" - ) - - assert len(added_files) == 0 - assert len(updated_files) == 1 - assert os.path.join(self.repodir, "crates/hyper/RUSTSEC-2020-0008.toml") in updated_files - - class TestOvalImporter(TestCase): @classmethod def setUpClass(cls): diff --git a/vulnerabilities/tests/test_gitlab.py b/vulnerabilities/tests/test_gitlab.py index bad3eae4f..f96483eec 100644 --- a/vulnerabilities/tests/test_gitlab.py +++ b/vulnerabilities/tests/test_gitlab.py @@ -9,12 +9,16 @@ import json import os +from pathlib import Path from unittest import mock import pytest +from packageurl import PackageURL from vulnerabilities.importer import AdvisoryData from vulnerabilities.importers.gitlab import GitLabBasicImprover +from vulnerabilities.importers.gitlab import get_gitlab_package_type +from vulnerabilities.importers.gitlab import get_purl from vulnerabilities.importers.gitlab import parse_gitlab_advisory from vulnerabilities.improvers.default import DefaultImprover from vulnerabilities.tests import util_tests @@ -84,3 +88,43 @@ def test_gitlab_improver(mock_response, pkg_type): inference = 
[data.to_dict() for data in improver.get_inferences(advisory)] result.extend(inference) util_tests.check_results_against_json(result, expected_file) + + +def test_get_purl(): + assert get_purl("nuget/MessagePack") == PackageURL(type="nuget", name="MessagePack") + assert get_purl("nuget/Microsoft.NETCore.App") == PackageURL( + type="nuget", name="Microsoft.NETCore.App" + ) + assert get_purl("npm/fresh") == PackageURL(type="npm", name="fresh") + + +def test_get_gitlab_package_type(): + assert ( + get_gitlab_package_type( + Path("/tmp/tmp9317bd5i/maven/com.google.gwt/gwt/CVE-2013-4204.yml"), + Path("/tmp/tmp9317bd5i/"), + ) + == "maven" + ) + assert ( + get_gitlab_package_type( + Path( + "/tmp/tmp9317bd5i/maven/io.projectreactor.netty/reactor-netty-http/CVE-2020-5404.yml" + ), + Path("/tmp/tmp9317bd5i/"), + ) + == "maven" + ) + assert ( + get_gitlab_package_type( + Path("/tmp/tmp9317bd5i/go/github.com/cloudflare/cfrpki/CVE-2021-3909.yml"), + Path("/tmp/tmp9317bd5i/"), + ) + == "go" + ) + assert ( + get_gitlab_package_type( + Path("/tmp/tmp9317bd5i/gem/rexml/CVE-2021-28965.yml"), Path("/tmp/tmp9317bd5i/") + ) + == "gem" + )