diff --git a/cve_bin_tool/cli.py b/cve_bin_tool/cli.py index f7f520dc9c..012871da75 100644 --- a/cve_bin_tool/cli.py +++ b/cve_bin_tool/cli.py @@ -830,8 +830,7 @@ def main(argv=None): error_mode=error_mode, ) - # if OLD_CACHE_DIR (from cvedb.py) exists, print warning - if Path(OLD_CACHE_DIR).exists(): + if OLD_CACHE_DIR.exists(): LOGGER.warning( f"Obsolete cache dir {OLD_CACHE_DIR} is no longer needed and can be removed." ) diff --git a/cve_bin_tool/cve_scanner.py b/cve_bin_tool/cve_scanner.py index ae1fa9104d..feef513e5f 100644 --- a/cve_bin_tool/cve_scanner.py +++ b/cve_bin_tool/cve_scanner.py @@ -5,7 +5,6 @@ import sys from collections import defaultdict from logging import Logger -from pathlib import Path from string import ascii_lowercase from typing import DefaultDict, Dict, List @@ -31,7 +30,7 @@ class CVEScanner: all_cve_version_info: Dict[str, VersionInfo] RANGE_UNSET: str = "" - dbname: str = str(Path(DISK_LOCATION_DEFAULT) / DBNAME) + dbname: str = str(DISK_LOCATION_DEFAULT / DBNAME) CONSOLE: Console = Console(file=sys.stderr, theme=cve_theme) ALPHA_TO_NUM: Dict[str, int] = dict(zip(ascii_lowercase, range(26))) diff --git a/cve_bin_tool/data_sources/curl_source.py b/cve_bin_tool/data_sources/curl_source.py index 4466f6623f..24121d5cae 100644 --- a/cve_bin_tool/data_sources/curl_source.py +++ b/cve_bin_tool/data_sources/curl_source.py @@ -5,7 +5,6 @@ import json import logging -from pathlib import Path import aiohttp @@ -66,7 +65,7 @@ async def download_curl_vulnerabilities(self, session: RateLimiter) -> None: async with await session.get(self.DATA_SOURCE_LINK) as response: response.raise_for_status() self.vulnerability_data = await response.json() - path = Path(str(Path(self.cachedir) / "vuln.json")) + path = self.cachedir / "vuln.json" filepath = path.resolve() async with FileIO(filepath, "w") as f: await f.write(json.dumps(self.vulnerability_data, indent=4)) diff --git a/cve_bin_tool/data_sources/epss_source.py b/cve_bin_tool/data_sources/epss_source.py index b8fe1bad46..66eea372fc 100644 --- a/cve_bin_tool/data_sources/epss_source.py +++ b/cve_bin_tool/data_sources/epss_source.py @@ -6,7 +6,6 @@ import csv import gzip import logging -import os from datetime import datetime, timedelta from io import StringIO from pathlib import Path @@ -34,8 +33,8 @@ def __init__(self, error_mode=ErrorMode.TruncTrace): self.error_mode = error_mode self.cachedir = self.CACHEDIR self.backup_cachedir = self.BACKUPCACHEDIR - self.epss_path = str(Path(self.cachedir) / "epss") - self.file_name = os.path.join(self.epss_path, "epss_scores-current.csv") + self.epss_path = self.cachedir / "epss" + self.file_name = self.epss_path / "epss_scores-current.csv" self.source_name = self.SOURCE async def update_epss(self): @@ -58,11 +57,11 @@ async def download_epss_data(self): """Downloads the EPSS CSV file and saves it to the local filesystem. The download is only performed if the file is older than 24 hours. """ - os.makedirs(self.epss_path, exist_ok=True) + self.epss_path.mkdir(parents=True, exist_ok=True) # Check if the file exists - if os.path.exists(self.file_name): + if self.file_name.exists(): # Get the modification time of the file - modified_time = os.path.getmtime(self.file_name) + modified_time = self.file_name.stat().st_mtime last_modified = datetime.fromtimestamp(modified_time) # Calculate the time difference between now and the last modified time @@ -80,8 +79,7 @@ async def download_epss_data(self): decompressed_data = gzip.decompress(await response.read()) # Save the downloaded data to the file - with open(self.file_name, "wb") as file: - file.write(decompressed_data) + self.file_name.write_bytes(decompressed_data) except aiohttp.ClientError as e: self.LOGGER.error(f"An error occurred during updating epss {e}") @@ -102,8 +100,7 @@ async def download_epss_data(self): decompressed_data = gzip.decompress(await response.read()) # Save the downloaded data to the file - with open(self.file_name, "wb") as file: - file.write(decompressed_data) + self.file_name.write_bytes(decompressed_data) except aiohttp.ClientError as e: self.LOGGER.error(f"An error occurred during downloading epss {e}") @@ -114,9 +111,8 @@ def parse_epss_data(self, file_path=None): if file_path is None: file_path = self.file_name - with open(file_path) as file: - # Read the content of the CSV file - decoded_data = file.read() + # Read the content of the CSV file + decoded_data = Path(file_path).read_text() # Create a CSV reader to read the data from the decoded CSV content reader = csv.reader(StringIO(decoded_data), delimiter=",") diff --git a/cve_bin_tool/data_sources/gad_source.py b/cve_bin_tool/data_sources/gad_source.py index 1f9da1c449..ff54eb17f4 100644 --- a/cve_bin_tool/data_sources/gad_source.py +++ b/cve_bin_tool/data_sources/gad_source.py @@ -8,7 +8,6 @@ import io import re import zipfile -from pathlib import Path import aiohttp import yaml @@ -39,7 +38,7 @@ def __init__( ): self.cachedir = self.CACHEDIR self.slugs = None - self.gad_path = str(Path(self.cachedir) / "gad") + self.gad_path = self.cachedir / "gad" self.source_name = self.SOURCE self.error_mode = error_mode @@ -90,8 +89,8 @@ async def fetch_cves(self): self.db = cvedb.CVEDB() - if not Path(self.gad_path).exists(): - Path(self.gad_path).mkdir() + if not self.gad_path.exists(): + self.gad_path.mkdir() # As no data, force full update self.incremental_update = False @@ -155,7 +154,7 @@ async def fetch_cves(self): async def update_cve_entries(self): """Updates CVE entries from CVEs in cache.""" - p = Path(self.gad_path).glob("**/*") + p = self.gad_path.glob("**/*") # Need to find files which are new to the cache last_update_timestamp = ( self.time_of_last_update.timestamp() diff --git a/cve_bin_tool/data_sources/nvd_source.py b/cve_bin_tool/data_sources/nvd_source.py index 4b865adcc1..c9e2bcd3be 100644 --- a/cve_bin_tool/data_sources/nvd_source.py +++ b/cve_bin_tool/data_sources/nvd_source.py @@ -13,7 +13,6 @@ import logging import re import sqlite3 -from pathlib import Path import aiohttp from rich.progress import track @@ -27,7 +26,6 @@ NVD_FILENAME_TEMPLATE, ) from cve_bin_tool.error_handler import ( - AttemptedToWriteOutsideCachedir, CVEDataForYearNotInCache, ErrorHandler, ErrorMode, @@ -78,7 +76,7 @@ def __init__( self.source_name = self.SOURCE # set up the db if needed - self.dbpath = str(Path(self.cachedir) / DBNAME) + self.dbpath = self.cachedir / DBNAME self.connection: sqlite3.Connection | None = None self.session = session self.cve_count = -1 @@ -544,12 +542,9 @@ async def cache_update( Update the cache for a single year of NVD data. """ filename = url.split("/")[-1] - # Ensure we only write to files within the cachedir - cache_path = Path(self.cachedir) - filepath = Path(str(cache_path / filename)).resolve() - if not str(filepath).startswith(str(cache_path.resolve())): - with ErrorHandler(mode=self.error_mode, logger=self.LOGGER): - raise AttemptedToWriteOutsideCachedir(filepath) + cache_path = self.cachedir + filepath = cache_path / filename + # Validate the contents of the cached file if filepath.is_file(): # Validate the sha and write out @@ -604,7 +599,7 @@ def load_nvd_year(self, year: int) -> dict[str, str | object]: Return the dict of CVE data for the given year. """ - filename = Path(self.cachedir) / self.NVDCVE_FILENAME_TEMPLATE.format(year) + filename = self.cachedir / self.NVDCVE_FILENAME_TEMPLATE.format(year) # Check if file exists if not filename.is_file(): with ErrorHandler(mode=self.error_mode, logger=self.LOGGER): diff --git a/cve_bin_tool/data_sources/osv_source.py b/cve_bin_tool/data_sources/osv_source.py index 8ccbe61552..59503ef6ac 100644 --- a/cve_bin_tool/data_sources/osv_source.py +++ b/cve_bin_tool/data_sources/osv_source.py @@ -7,7 +7,6 @@ import datetime import io import json -import os import shutil import zipfile from pathlib import Path @@ -25,7 +24,7 @@ def find_gsutil(): gsutil_path = shutil.which("gsutil") - if not os.path.exists(gsutil_path): + if not Path(gsutil_path).exists(): raise FileNotFoundError( "gsutil not found. Did you need to install requirements or activate a venv where gsutil is installed?" ) @@ -46,7 +45,7 @@ def __init__( ): self.cachedir = self.CACHEDIR self.ecosystems = None - self.osv_path = str(Path(self.cachedir) / "osv") + self.osv_path = self.cachedir / "osv" self.source_name = self.SOURCE self.error_mode = error_mode @@ -104,7 +103,7 @@ async def get_ecosystem_incremental(self, ecosystem, time_of_last_update, sessio tasks.append(task) for r in await asyncio.gather(*tasks): - filepath = Path(self.osv_path) / (r.get("id") + ".json") + filepath = self.osv_path / (r.get("id") + ".json") r = json.dumps(r) async with FileIO(filepath, "w") as f: @@ -149,9 +148,9 @@ async def get_totalfiles(self, ecosystem): gsutil_path = find_gsutil() # use helper function gs_file = self.gs_url + ecosystem + "/all.zip" - await aio_run_command([gsutil_path, "cp", gs_file, self.osv_path]) + await aio_run_command([gsutil_path, "cp", gs_file, str(self.osv_path)]) - zip_path = Path(self.osv_path) / "all.zip" + zip_path = self.osv_path / "all.zip" totalfiles = 0 with zipfile.ZipFile(zip_path, "r") as z: @@ -170,8 +169,8 @@ async def fetch_cves(self): self.db = cvedb.CVEDB() - if not Path(self.osv_path).exists(): - Path(self.osv_path).mkdir() + if not self.osv_path.exists(): + self.osv_path.mkdir() # As no data, force full update self.incremental_update = False @@ -230,7 +229,7 @@ async def fetch_cves(self): async def update_cve_entries(self): """Updates CVE entries from CVEs in cache""" - p = Path(self.osv_path).glob("**/*") + p = self.osv_path.glob("**/*") # Need to find files which are new to the cache last_update_timestamp = ( self.time_of_last_update.timestamp() diff --git a/cve_bin_tool/data_sources/purl2cpe_source.py b/cve_bin_tool/data_sources/purl2cpe_source.py index 5c3e35639a..40b573bb26 100644 --- a/cve_bin_tool/data_sources/purl2cpe_source.py +++ b/cve_bin_tool/data_sources/purl2cpe_source.py @@ -2,7 +2,6 @@ import zipfile from io import BytesIO -from pathlib import Path import aiohttp @@ -25,7 +24,7 @@ def __init__( self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False ): self.cachedir = self.CACHEDIR - self.purl2cpe_path = str(Path(self.cachedir) / "purl2cpe") + self.purl2cpe_path = self.cachedir / "purl2cpe" self.source_name = self.SOURCE self.error_mode = error_mode self.incremental_update = incremental_update @@ -36,8 +35,8 @@ async def fetch_cves(self): """Fetches PURL2CPE database and places it in purl2cpe_path.""" LOGGER.info("Getting PURL2CPE data...") - if not Path(self.purl2cpe_path).exists(): - Path(self.purl2cpe_path).mkdir() + if not self.purl2cpe_path.exists(): + self.purl2cpe_path.mkdir() if not self.session: connector = aiohttp.TCPConnector(limit_per_host=10) diff --git a/cve_bin_tool/data_sources/redhat_source.py b/cve_bin_tool/data_sources/redhat_source.py index d8386fa587..2f377882d0 100644 --- a/cve_bin_tool/data_sources/redhat_source.py +++ b/cve_bin_tool/data_sources/redhat_source.py @@ -3,7 +3,6 @@ import datetime import json -from pathlib import Path import aiohttp @@ -28,7 +27,7 @@ def __init__( self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False ): self.cachedir = self.CACHEDIR - self.redhat_path = str(Path(self.cachedir) / "redhat") + self.redhat_path = self.cachedir / "redhat" self.source_name = self.SOURCE self.error_mode = error_mode @@ -57,7 +56,7 @@ async def store_data(self, content): """Asynchronously stores CVE data in separate JSON files, excluding entries without a CVE ID.""" for c in content: if c["CVE"] != "": - filepath = Path(self.redhat_path) / (str(c["CVE"]) + ".json") + filepath = self.redhat_path / (str(c["CVE"]) + ".json") r = json.dumps(c) async with FileIO(filepath, "w") as f: await f.write(r) @@ -73,8 +72,8 @@ async def fetch_cves(self): self.db = cvedb.CVEDB() - if not Path(self.redhat_path).exists(): - Path(self.redhat_path).mkdir() + if not self.redhat_path.exists(): + self.redhat_path.mkdir() # As no data, force full update self.incremental_update = False @@ -121,7 +120,7 @@ async def fetch_cves(self): async def update_cve_entries(self): """Updates CVE entries from CVEs in cache.""" - p = Path(self.redhat_path).glob("**/*") + p = self.redhat_path.glob("**/*") # Need to find files which are new to the cache last_update_timestamp = ( self.time_of_last_update.timestamp() diff --git a/cve_bin_tool/data_sources/rsd_source.py b/cve_bin_tool/data_sources/rsd_source.py index 5edd246a34..0eb37c03a0 100644 --- a/cve_bin_tool/data_sources/rsd_source.py +++ b/cve_bin_tool/data_sources/rsd_source.py @@ -8,7 +8,6 @@ import io import json import zipfile -from pathlib import Path import aiohttp from cvss import CVSS2, CVSS3 @@ -36,7 +35,7 @@ def __init__( self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False ): self.cachedir = self.CACHEDIR - self.rsd_path = str(Path(self.cachedir) / "rsd") + self.rsd_path = self.cachedir / "rsd" self.source_name = self.SOURCE self.error_mode = error_mode @@ -71,8 +70,8 @@ async def fetch_cves(self): self.db = cvedb.CVEDB() - if not Path(self.rsd_path).exists(): - Path(self.rsd_path).mkdir() + if not self.rsd_path.exists(): + self.rsd_path.mkdir() if not self.session: connector = aiohttp.TCPConnector(limit_per_host=19) @@ -133,7 +132,7 @@ async def fetch_cves(self): async def update_cve_entries(self): """Updates CVE entries from CVEs in cache.""" - p = Path(self.rsd_path).glob("**/*") + p = self.rsd_path.glob("**/*") # Need to find files which are new to the cache last_update_timestamp = ( self.time_of_last_update.timestamp() diff --git a/cve_bin_tool/helper_script.py b/cve_bin_tool/helper_script.py index c3f51f454a..8dc01f0c81 100644 --- a/cve_bin_tool/helper_script.py +++ b/cve_bin_tool/helper_script.py @@ -9,7 +9,6 @@ import textwrap from collections import ChainMap from logging import Logger -from pathlib import Path from typing import MutableMapping from rich import print as rprint @@ -46,7 +45,7 @@ def __init__( # for setting the database self.connection = None - self.dbpath = str(Path(DISK_LOCATION_DEFAULT) / DBNAME) + self.dbpath = DISK_LOCATION_DEFAULT / DBNAME # for extraction self.walker = DirWalk().walk diff --git a/cve_bin_tool/version_signature.py b/cve_bin_tool/version_signature.py index cd10875b64..60bbf811f5 100644 --- a/cve_bin_tool/version_signature.py +++ b/cve_bin_tool/version_signature.py @@ -6,7 +6,6 @@ import sqlite3 import time from datetime import datetime -from pathlib import Path from cve_bin_tool.database_defaults import DISK_LOCATION_DEFAULT @@ -36,7 +35,7 @@ def __init__(self, table_name, mapping_function, duration) -> None: @property def dbname(self) -> str: """SQLite datebase file where the data is stored.""" - return str(Path(self.disk_location) / "version_map.db") + return str(self.disk_location / "version_map.db") def open(self) -> None: """Opens connection to sqlite database.""" diff --git a/test/test_cvedb.py b/test/test_cvedb.py index 80742ebbb4..c2f4850a08 100644 --- a/test/test_cvedb.py +++ b/test/test_cvedb.py @@ -4,6 +4,7 @@ import datetime import shutil import tempfile +from pathlib import Path from test.utils import EXTERNAL_SYSTEM, LONG_TESTS import pytest @@ -17,9 +18,9 @@ class TestCVEDB: @classmethod def setup_class(cls): cls.nvd = nvd_source.NVD_Source(nvd_type="json") - cachedir = tempfile.mkdtemp(prefix="cvedb-") - cls.exported_data = tempfile.mkdtemp(prefix="exported-data-") - cls.cvedb = cvedb.CVEDB(sources=[cls.nvd], cachedir=cachedir) + cachedir = Path(tempfile.mkdtemp(prefix="cvedb-")) + cls.exported_data = Path(tempfile.mkdtemp(prefix="exported-data-")) + cls.cvedb = cvedb.CVEDB(sources=[cls.nvd], cachedir=str(cachedir)) cls.nvd.cachedir = cachedir @classmethod @@ -39,7 +40,7 @@ async def test_refresh_nvd_json(self): @pytest.mark.skipif(not LONG_TESTS(), reason="Skipping long tests") def test_import_export_json(self): - main(["cve-bin-tool", "-u", "never", "--export", self.nvd.cachedir]) + main(["cve-bin-tool", "-u", "never", "--export", str(self.nvd.cachedir)]) cve_entries_check = "SELECT data_source, COUNT(*) as number FROM cve_severity GROUP BY data_source ORDER BY number DESC" cursor = self.cvedb.db_open_and_get_cursor() cursor.execute(cve_entries_check) diff --git a/test/test_json.py b/test/test_json.py index 7ab1225314..338bd88421 100644 --- a/test/test_json.py +++ b/test/test_json.py @@ -9,7 +9,6 @@ import datetime import gzip import json -from pathlib import Path from test.utils import EXTERNAL_SYSTEM, LONG_TESTS import pytest @@ -43,7 +42,7 @@ def test_json_validation(self, year): """Validate latest nvd json file against their published schema""" # Open the latest nvd file on disk with gzip.open( - Path(DISK_LOCATION_DEFAULT) / f"nvdcve-1.1-{year}.json.gz", + DISK_LOCATION_DEFAULT / f"nvdcve-1.1-{year}.json.gz", "rb", ) as json_file: nvd_json = json.loads(json_file.read()) diff --git a/test/test_source_gad.py b/test/test_source_gad.py index f68aac9464..df2122a287 100644 --- a/test/test_source_gad.py +++ b/test/test_source_gad.py @@ -18,8 +18,8 @@ class TestSourceGAD: @classmethod def setup_class(cls): cls.gad = gad_source.GAD_Source() - cls.gad.cachedir = tempfile.mkdtemp(prefix="cvedb-") - cls.gad.gad_path = str(Path(cls.gad.cachedir) / "gad") + cls.gad.cachedir = Path(tempfile.mkdtemp(prefix="cvedb-")) + cls.gad.gad_path = cls.gad.cachedir / "gad" @classmethod def teardown_class(cls): diff --git a/test/test_source_nvd.py b/test/test_source_nvd.py index a3d99b5677..100175a1e1 100644 --- a/test/test_source_nvd.py +++ b/test/test_source_nvd.py @@ -3,6 +3,7 @@ import shutil import tempfile +from pathlib import Path from test.utils import EXTERNAL_SYSTEM import aiohttp @@ -15,7 +16,7 @@ class TestSourceNVD: @classmethod def setup_class(cls): cls.nvd = nvd_source.NVD_Source() - cls.nvd.cachedir = tempfile.mkdtemp(prefix="cvedb-") + cls.nvd.cachedir = Path(tempfile.mkdtemp(prefix="cvedb-")) @classmethod def teardown_class(cls): diff --git a/test/test_source_osv.py b/test/test_source_osv.py index da3c599c2b..6fb509bada 100644 --- a/test/test_source_osv.py +++ b/test/test_source_osv.py @@ -20,8 +20,8 @@ class TestSourceOSV: @classmethod def setup_class(cls): cls.osv = osv_source.OSV_Source() - cls.osv.cachedir = tempfile.mkdtemp(prefix="cvedb-") - cls.osv.osv_path = str(Path(cls.osv.cachedir) / "osv") + cls.osv.cachedir = Path(tempfile.mkdtemp(prefix="cvedb-")) + cls.osv.osv_path = cls.osv.cachedir / "osv" @classmethod def teardown_class(cls): @@ -228,7 +228,7 @@ async def test_fetch_cves(self): await self.osv.fetch_cves() - p = Path(self.osv.osv_path).glob("**/*") + p = self.osv.osv_path.glob("**/*") files = [x.name for x in p if x.is_file()] # Check some files have been processed diff --git a/test/test_source_purl2cpe.py b/test/test_source_purl2cpe.py index c28b16deb9..f7f0ce4b3c 100644 --- a/test/test_source_purl2cpe.py +++ b/test/test_source_purl2cpe.py @@ -18,8 +18,8 @@ class TestSourceOSV: @classmethod def setup_class(cls): cls.purl2cpe = purl2cpe_source.PURL2CPE_Source() - cls.purl2cpe.cachedir = tempfile.mkdtemp(prefix="cvedb-") - cls.purl2cpe.purl2cpe_path = str(Path(cls.purl2cpe.cachedir) / "purl2cpe") + cls.purl2cpe.cachedir = Path(tempfile.mkdtemp(prefix="cvedb-")) + cls.purl2cpe.purl2cpe_path = cls.purl2cpe.cachedir / "purl2cpe" cls.local_path = Path("~").expanduser() / ".cache" / "cve-bin-tool" / "purl2cpe" @classmethod