refactor: use Path objects instead of strings #5142

Open · wants to merge 1 commit into base: main
3 changes: 1 addition & 2 deletions cve_bin_tool/cli.py

@@ -830,8 +830,7 @@ def main(argv=None):
         error_mode=error_mode,
     )

-    # if OLD_CACHE_DIR (from cvedb.py) exists, print warning
-    if Path(OLD_CACHE_DIR).exists():
+    if OLD_CACHE_DIR.exists():
         LOGGER.warning(
             f"Obsolete cache dir {OLD_CACHE_DIR} is no longer needed and can be removed."
         )
3 changes: 1 addition & 2 deletions cve_bin_tool/cve_scanner.py

@@ -5,7 +5,6 @@
 import sys
 from collections import defaultdict
 from logging import Logger
-from pathlib import Path
 from string import ascii_lowercase
 from typing import DefaultDict, Dict, List

@@ -31,7 +30,7 @@ class CVEScanner:
     all_cve_version_info: Dict[str, VersionInfo]

     RANGE_UNSET: str = ""
-    dbname: str = str(Path(DISK_LOCATION_DEFAULT) / DBNAME)
+    dbname: str = str(DISK_LOCATION_DEFAULT / DBNAME)
     CONSOLE: Console = Console(file=sys.stderr, theme=cve_theme)
     ALPHA_TO_NUM: Dict[str, int] = dict(zip(ascii_lowercase, range(26)))
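Both hunks above lean on the same pathlib idiom: constants such as OLD_CACHE_DIR and DISK_LOCATION_DEFAULT are now Path objects, so call sites join with the / operator and convert with str() only at boundaries that genuinely require a string (here, the sqlite3 database name). A minimal sketch of the idiom, with illustrative values:

```python
from pathlib import Path

DISK_LOCATION_DEFAULT = Path("~/.cache/cve-bin-tool").expanduser()  # illustrative value
DBNAME = "cve.db"

# Path / str joins and returns a new Path; no os.path.join or re-wrapping needed
dbpath = DISK_LOCATION_DEFAULT / DBNAME

# convert only at the boundary that demands a plain string
dbname: str = str(dbpath)
print(dbname)
```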
3 changes: 1 addition & 2 deletions cve_bin_tool/data_sources/curl_source.py

@@ -5,7 +5,6 @@

 import json
 import logging
-from pathlib import Path

 import aiohttp

@@ -66,7 +65,7 @@ async def download_curl_vulnerabilities(self, session: RateLimiter) -> None:
         async with await session.get(self.DATA_SOURCE_LINK) as response:
             response.raise_for_status()
             self.vulnerability_data = await response.json()
-        path = Path(str(Path(self.cachedir) / "vuln.json"))
+        path = self.cachedir / "vuln.json"
         filepath = path.resolve()
         async with FileIO(filepath, "w") as f:
             await f.write(json.dumps(self.vulnerability_data, indent=4))
22 changes: 9 additions & 13 deletions cve_bin_tool/data_sources/epss_source.py

@@ -6,7 +6,6 @@
 import csv
 import gzip
 import logging
-import os
 from datetime import datetime, timedelta
 from io import StringIO
 from pathlib import Path

@@ -34,8 +33,8 @@ def __init__(self, error_mode=ErrorMode.TruncTrace):
         self.error_mode = error_mode
         self.cachedir = self.CACHEDIR
         self.backup_cachedir = self.BACKUPCACHEDIR
-        self.epss_path = str(Path(self.cachedir) / "epss")
-        self.file_name = os.path.join(self.epss_path, "epss_scores-current.csv")
+        self.epss_path = self.cachedir / "epss"
+        self.file_name = self.epss_path / "epss_scores-current.csv"
         self.source_name = self.SOURCE

     async def update_epss(self):

@@ -58,11 +57,11 @@ async def download_epss_data(self):
         """Downloads the EPSS CSV file and saves it to the local filesystem.
         The download is only performed if the file is older than 24 hours.
         """
-        os.makedirs(self.epss_path, exist_ok=True)
+        self.epss_path.mkdir(parents=True, exist_ok=True)
         # Check if the file exists
-        if os.path.exists(self.file_name):
+        if self.file_name.exists():
             # Get the modification time of the file
-            modified_time = os.path.getmtime(self.file_name)
+            modified_time = self.file_name.stat().st_mtime
             last_modified = datetime.fromtimestamp(modified_time)

             # Calculate the time difference between now and the last modified time

@@ -80,8 +79,7 @@
                         decompressed_data = gzip.decompress(await response.read())

                         # Save the downloaded data to the file
-                        with open(self.file_name, "wb") as file:
-                            file.write(decompressed_data)
+                        self.file_name.write_bytes(decompressed_data)

         except aiohttp.ClientError as e:
             self.LOGGER.error(f"An error occurred during updating epss {e}")

@@ -102,8 +100,7 @@
                         decompressed_data = gzip.decompress(await response.read())

                         # Save the downloaded data to the file
-                        with open(self.file_name, "wb") as file:
-                            file.write(decompressed_data)
+                        self.file_name.write_bytes(decompressed_data)

         except aiohttp.ClientError as e:
             self.LOGGER.error(f"An error occurred during downloading epss {e}")

@@ -114,9 +111,8 @@ def parse_epss_data(self, file_path=None):
         if file_path is None:
             file_path = self.file_name

-        with open(file_path) as file:
-            # Read the content of the CSV file
-            decoded_data = file.read()
+        # Read the content of the CSV file
+        decoded_data = Path(file_path).read_text()

         # Create a CSV reader to read the data from the decoded CSV content
         reader = csv.reader(StringIO(decoded_data), delimiter=",")
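For reference, the os and os.path calls dropped above map one-to-one onto pathlib methods. A quick sketch of the correspondence (the path is illustrative, not the tool's real cache location):

```python
from pathlib import Path

epss_path = Path("/tmp/epss-demo")                 # illustrative path
file_name = epss_path / "epss_scores-current.csv"  # was: os.path.join(...)

epss_path.mkdir(parents=True, exist_ok=True)       # was: os.makedirs(..., exist_ok=True)
file_name.write_bytes(b"cve,epss,percentile\n")    # was: open(..., "wb") + file.write(...)
if file_name.exists():                             # was: os.path.exists(...)
    modified_time = file_name.stat().st_mtime      # was: os.path.getmtime(...)
decoded_data = file_name.read_text()               # was: open(...) + file.read()
```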
9 changes: 4 additions & 5 deletions cve_bin_tool/data_sources/gad_source.py

@@ -8,7 +8,6 @@
 import io
 import re
 import zipfile
-from pathlib import Path

 import aiohttp
 import yaml

@@ -39,7 +38,7 @@ def __init__(
     ):
         self.cachedir = self.CACHEDIR
         self.slugs = None
-        self.gad_path = str(Path(self.cachedir) / "gad")
+        self.gad_path = self.cachedir / "gad"
         self.source_name = self.SOURCE

         self.error_mode = error_mode

@@ -90,8 +89,8 @@ async def fetch_cves(self):

         self.db = cvedb.CVEDB()

-        if not Path(self.gad_path).exists():
-            Path(self.gad_path).mkdir()
+        if not self.gad_path.exists():
+            self.gad_path.mkdir()
             # As no data, force full update
             self.incremental_update = False

@@ -155,7 +154,7 @@ async def fetch_cves(self):
     async def update_cve_entries(self):
         """Updates CVE entries from CVEs in cache."""

-        p = Path(self.gad_path).glob("**/*")
+        p = self.gad_path.glob("**/*")
         # Need to find files which are new to the cache
         last_update_timestamp = (
             self.time_of_last_update.timestamp()
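Path.glob yields Path objects rather than strings, so the mtime filtering that follows the glob keeps working without conversion. A sketch of the recursive-listing pattern used by this and the other data sources (path and cutoff are illustrative):

```python
from pathlib import Path

gad_path = Path("/tmp/gad-demo")  # illustrative cache directory
last_update_timestamp = 0.0       # illustrative cutoff from time_of_last_update

# "**/*" recurses through subdirectories; each result is a Path with .stat() available
new_files = [
    f for f in gad_path.glob("**/*")
    if f.is_file() and f.stat().st_mtime > last_update_timestamp
]
```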
15 changes: 5 additions & 10 deletions cve_bin_tool/data_sources/nvd_source.py

@@ -13,7 +13,6 @@
 import logging
 import re
 import sqlite3
-from pathlib import Path

 import aiohttp
 from rich.progress import track

@@ -27,7 +26,6 @@
     NVD_FILENAME_TEMPLATE,
 )
 from cve_bin_tool.error_handler import (
-    AttemptedToWriteOutsideCachedir,
     CVEDataForYearNotInCache,
     ErrorHandler,
     ErrorMode,

@@ -78,7 +76,7 @@ def __init__(
         self.source_name = self.SOURCE

         # set up the db if needed
-        self.dbpath = str(Path(self.cachedir) / DBNAME)
+        self.dbpath = self.cachedir / DBNAME
         self.connection: sqlite3.Connection | None = None
         self.session = session
         self.cve_count = -1

@@ -544,12 +542,9 @@ async def cache_update(
         Update the cache for a single year of NVD data.
         """
         filename = url.split("/")[-1]
-        # Ensure we only write to files within the cachedir
-        cache_path = Path(self.cachedir)
-        filepath = Path(str(cache_path / filename)).resolve()
-        if not str(filepath).startswith(str(cache_path.resolve())):
-            with ErrorHandler(mode=self.error_mode, logger=self.LOGGER):
-                raise AttemptedToWriteOutsideCachedir(filepath)
+        cache_path = self.cachedir
+        filepath = cache_path / filename

         # Validate the contents of the cached file
         if filepath.is_file():
             # Validate the sha and write out

@@ -604,7 +599,7 @@ def load_nvd_year(self, year: int) -> dict[str, str | object]:
         Return the dict of CVE data for the given year.
         """

-        filename = Path(self.cachedir) / self.NVDCVE_FILENAME_TEMPLATE.format(year)
+        filename = self.cachedir / self.NVDCVE_FILENAME_TEMPLATE.format(year)
         # Check if file exists
         if not filename.is_file():
             with ErrorHandler(mode=self.error_mode, logger=self.LOGGER):
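Note that the cachedir containment guard removed above could also be kept in pure pathlib form if it is still wanted: Path.is_relative_to (Python 3.9+) expresses the old str.startswith comparison directly. A sketch, with a plain ValueError standing in for AttemptedToWriteOutsideCachedir:

```python
from pathlib import Path

def cache_file_for(cachedir: Path, filename: str) -> Path:
    """Join filename onto cachedir, refusing any result that escapes it."""
    filepath = (cachedir / filename).resolve()
    if not filepath.is_relative_to(cachedir.resolve()):
        raise ValueError(f"attempted to write outside cachedir: {filepath}")
    return filepath
```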
17 changes: 8 additions & 9 deletions cve_bin_tool/data_sources/osv_source.py

@@ -7,7 +7,6 @@
 import datetime
 import io
 import json
-import os
 import shutil
 import zipfile
 from pathlib import Path

@@ -25,7 +24,7 @@

 def find_gsutil():
     gsutil_path = shutil.which("gsutil")
-    if not os.path.exists(gsutil_path):
+    if not Path(gsutil_path).exists():
         raise FileNotFoundError(
             "gsutil not found. Did you need to install requirements or activate a venv where gsutil is installed?"
         )

@@ -46,7 +45,7 @@ def __init__(
     ):
         self.cachedir = self.CACHEDIR
         self.ecosystems = None
-        self.osv_path = str(Path(self.cachedir) / "osv")
+        self.osv_path = self.cachedir / "osv"
         self.source_name = self.SOURCE

         self.error_mode = error_mode

@@ -104,7 +103,7 @@ async def get_ecosystem_incremental(self, ecosystem, time_of_last_update, sessio
             tasks.append(task)

         for r in await asyncio.gather(*tasks):
-            filepath = Path(self.osv_path) / (r.get("id") + ".json")
+            filepath = self.osv_path / (r.get("id") + ".json")
             r = json.dumps(r)

             async with FileIO(filepath, "w") as f:

@@ -149,9 +148,9 @@ async def get_totalfiles(self, ecosystem):

         gsutil_path = find_gsutil()  # use helper function
         gs_file = self.gs_url + ecosystem + "/all.zip"
-        await aio_run_command([gsutil_path, "cp", gs_file, self.osv_path])
+        await aio_run_command([gsutil_path, "cp", gs_file, str(self.osv_path)])

-        zip_path = Path(self.osv_path) / "all.zip"
+        zip_path = self.osv_path / "all.zip"
         totalfiles = 0

         with zipfile.ZipFile(zip_path, "r") as z:

@@ -170,8 +169,8 @@ async def fetch_cves(self):

         self.db = cvedb.CVEDB()

-        if not Path(self.osv_path).exists():
-            Path(self.osv_path).mkdir()
+        if not self.osv_path.exists():
+            self.osv_path.mkdir()
             # As no data, force full update
             self.incremental_update = False

@@ -230,7 +229,7 @@ async def fetch_cves(self):
     async def update_cve_entries(self):
         """Updates CVE entries from CVEs in cache"""

-        p = Path(self.osv_path).glob("**/*")
+        p = self.osv_path.glob("**/*")
         # Need to find files which are new to the cache
         last_update_timestamp = (
             self.time_of_last_update.timestamp()
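The str(self.osv_path) in get_totalfiles is the one conversion this file keeps: the Path must cross into a subprocess argv. The standard library's subprocess accepts os.PathLike arguments directly, but an internal runner that joins or logs its arguments may expect plain strings, so converting at the call site is the conservative choice. A sketch assuming the standard library runner and an illustrative bucket URL:

```python
import shutil
import subprocess
from pathlib import Path

osv_path = Path("/tmp/osv-demo")  # illustrative destination directory
gsutil_path = shutil.which("gsutil")
if gsutil_path is None:  # shutil.which returns None when the binary is absent
    raise FileNotFoundError("gsutil not found")

# str() keeps the argv uniform for runners (or log formatting) that expect strings
subprocess.run([gsutil_path, "cp", "gs://example-bucket/all.zip", str(osv_path)])
```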
7 changes: 3 additions & 4 deletions cve_bin_tool/data_sources/purl2cpe_source.py

@@ -2,7 +2,6 @@

 import zipfile
 from io import BytesIO
-from pathlib import Path

 import aiohttp

@@ -25,7 +24,7 @@ def __init__(
         self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
     ):
         self.cachedir = self.CACHEDIR
-        self.purl2cpe_path = str(Path(self.cachedir) / "purl2cpe")
+        self.purl2cpe_path = self.cachedir / "purl2cpe"
         self.source_name = self.SOURCE
         self.error_mode = error_mode
         self.incremental_update = incremental_update

@@ -36,8 +35,8 @@ async def fetch_cves(self):
         """Fetches PURL2CPE database and places it in purl2cpe_path."""
         LOGGER.info("Getting PURL2CPE data...")

-        if not Path(self.purl2cpe_path).exists():
-            Path(self.purl2cpe_path).mkdir()
+        if not self.purl2cpe_path.exists():
+            self.purl2cpe_path.mkdir()

         if not self.session:
             connector = aiohttp.TCPConnector(limit_per_host=10)
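purl2cpe, gad, osv, redhat, and rsd all guard mkdir with an exists() check; the epss change earlier shows the more compact single-call equivalent, which also avoids the race between the check and the creation. A sketch of the two forms (path is illustrative):

```python
from pathlib import Path

purl2cpe_path = Path("/tmp/purl2cpe-demo")  # illustrative path

# form used by the data sources above: check, then create
if not purl2cpe_path.exists():
    purl2cpe_path.mkdir()

# equivalent single call, as in epss_source.py; exist_ok avoids the check/create race
purl2cpe_path.mkdir(parents=True, exist_ok=True)
```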
11 changes: 5 additions & 6 deletions cve_bin_tool/data_sources/redhat_source.py

@@ -3,7 +3,6 @@

 import datetime
 import json
-from pathlib import Path

 import aiohttp

@@ -28,7 +27,7 @@ def __init__(
         self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
     ):
         self.cachedir = self.CACHEDIR
-        self.redhat_path = str(Path(self.cachedir) / "redhat")
+        self.redhat_path = self.cachedir / "redhat"
         self.source_name = self.SOURCE

         self.error_mode = error_mode

@@ -57,7 +56,7 @@ async def store_data(self, content):
         """Asynchronously stores CVE data in separate JSON files, excluding entries without a CVE ID."""
         for c in content:
             if c["CVE"] != "":
-                filepath = Path(self.redhat_path) / (str(c["CVE"]) + ".json")
+                filepath = self.redhat_path / (str(c["CVE"]) + ".json")
                 r = json.dumps(c)
                 async with FileIO(filepath, "w") as f:
                     await f.write(r)

@@ -73,8 +72,8 @@ async def fetch_cves(self):

         self.db = cvedb.CVEDB()

-        if not Path(self.redhat_path).exists():
-            Path(self.redhat_path).mkdir()
+        if not self.redhat_path.exists():
+            self.redhat_path.mkdir()
             # As no data, force full update
             self.incremental_update = False

@@ -121,7 +120,7 @@ async def fetch_cves(self):
     async def update_cve_entries(self):
         """Updates CVE entries from CVEs in cache."""

-        p = Path(self.redhat_path).glob("**/*")
+        p = self.redhat_path.glob("**/*")
         # Need to find files which are new to the cache
         last_update_timestamp = (
             self.time_of_last_update.timestamp()
9 changes: 4 additions & 5 deletions cve_bin_tool/data_sources/rsd_source.py

@@ -8,7 +8,6 @@
 import io
 import json
 import zipfile
-from pathlib import Path

 import aiohttp
 from cvss import CVSS2, CVSS3

@@ -36,7 +35,7 @@ def __init__(
         self, error_mode: ErrorMode = ErrorMode.TruncTrace, incremental_update=False
     ):
         self.cachedir = self.CACHEDIR
-        self.rsd_path = str(Path(self.cachedir) / "rsd")
+        self.rsd_path = self.cachedir / "rsd"
         self.source_name = self.SOURCE

         self.error_mode = error_mode

@@ -71,8 +70,8 @@ async def fetch_cves(self):

         self.db = cvedb.CVEDB()

-        if not Path(self.rsd_path).exists():
-            Path(self.rsd_path).mkdir()
+        if not self.rsd_path.exists():
+            self.rsd_path.mkdir()

         if not self.session:
             connector = aiohttp.TCPConnector(limit_per_host=19)

@@ -133,7 +132,7 @@ async def fetch_cves(self):
     async def update_cve_entries(self):
         """Updates CVE entries from CVEs in cache."""

-        p = Path(self.rsd_path).glob("**/*")
+        p = self.rsd_path.glob("**/*")
         # Need to find files which are new to the cache
         last_update_timestamp = (
             self.time_of_last_update.timestamp()
3 changes: 1 addition & 2 deletions cve_bin_tool/helper_script.py

@@ -9,7 +9,6 @@
 import textwrap
 from collections import ChainMap
 from logging import Logger
-from pathlib import Path
 from typing import MutableMapping

 from rich import print as rprint

@@ -46,7 +45,7 @@ def __init__(

         # for setting the database
         self.connection = None
-        self.dbpath = str(Path(DISK_LOCATION_DEFAULT) / DBNAME)
+        self.dbpath = DISK_LOCATION_DEFAULT / DBNAME

         # for extraction
         self.walker = DirWalk().walk