diff --git a/crates/socket-patch-cli/Cargo.toml b/crates/socket-patch-cli/Cargo.toml index 917946e..439b02c 100644 --- a/crates/socket-patch-cli/Cargo.toml +++ b/crates/socket-patch-cli/Cargo.toml @@ -22,6 +22,10 @@ uuid = { workspace = true } regex = { workspace = true } tempfile = { workspace = true } +[features] +default = [] +cargo = ["socket-patch-core/cargo"] + [dev-dependencies] sha2 = { workspace = true } hex = { workspace = true } diff --git a/crates/socket-patch-cli/src/commands/apply.rs b/crates/socket-patch-cli/src/commands/apply.rs index 24aeb7d..cf196bd 100644 --- a/crates/socket-patch-cli/src/commands/apply.rs +++ b/crates/socket-patch-cli/src/commands/apply.rs @@ -4,15 +4,17 @@ use socket_patch_core::api::blob_fetcher::{ }; use socket_patch_core::api::client::get_api_client_from_env; use socket_patch_core::constants::DEFAULT_PATCH_MANIFEST_PATH; -use socket_patch_core::crawlers::{CrawlerOptions, NpmCrawler, PythonCrawler}; +use socket_patch_core::crawlers::{CrawlerOptions, Ecosystem}; use socket_patch_core::manifest::operations::read_manifest; use socket_patch_core::patch::apply::{apply_package_patch, verify_file_patch, ApplyResult}; use socket_patch_core::utils::cleanup_blobs::{cleanup_unused_blobs, format_cleanup_result}; -use socket_patch_core::utils::purl::{is_npm_purl, is_pypi_purl, strip_purl_qualifiers}; +use socket_patch_core::utils::purl::strip_purl_qualifiers; use socket_patch_core::utils::telemetry::{track_patch_applied, track_patch_apply_failed}; use std::collections::{HashMap, HashSet}; use std::path::{Path, PathBuf}; +use crate::ecosystem_dispatch::{find_packages_for_purls, partition_purls}; + #[derive(Args)] pub struct ApplyArgs { /// Working directory @@ -173,18 +175,8 @@ async fn apply_patches_inner( // Partition manifest PURLs by ecosystem let manifest_purls: Vec = manifest.patches.keys().cloned().collect(); - let mut npm_purls: Vec = manifest_purls.iter().filter(|p| is_npm_purl(p)).cloned().collect(); - let mut pypi_purls: Vec = manifest_purls.iter().filter(|p| is_pypi_purl(p)).cloned().collect(); - - // Filter by ecosystem if specified - if let Some(ref ecosystems) = args.ecosystems { - if !ecosystems.iter().any(|e| e == "npm") { - npm_purls.clear(); - } - if !ecosystems.iter().any(|e| e == "pypi") { - pypi_purls.clear(); - } - } + let partitioned = + partition_purls(&manifest_purls, args.ecosystems.as_deref()); let crawler_options = CrawlerOptions { cwd: args.cwd.clone(), @@ -193,63 +185,12 @@ async fn apply_patches_inner( batch_size: 100, }; - let mut all_packages: HashMap = HashMap::new(); + let all_packages = + find_packages_for_purls(&partitioned, &crawler_options, args.silent).await; - // Find npm packages - if !npm_purls.is_empty() { - let npm_crawler = NpmCrawler; - match npm_crawler.get_node_modules_paths(&crawler_options).await { - Ok(nm_paths) => { - if (args.global || args.global_prefix.is_some()) && !args.silent { - if let Some(first) = nm_paths.first() { - println!("Using global npm packages at: {}", first.display()); - } - } - for nm_path in &nm_paths { - if let Ok(packages) = npm_crawler.find_by_purls(nm_path, &npm_purls).await { - for (purl, pkg) in packages { - all_packages.entry(purl).or_insert(pkg.path); - } - } - } - } - Err(e) => { - if !args.silent { - eprintln!("Failed to find npm packages: {e}"); - } - } - } - } + let has_any_purls = !partitioned.is_empty(); - // Find Python packages - if !pypi_purls.is_empty() { - let python_crawler = PythonCrawler; - let base_pypi_purls: Vec = pypi_purls - .iter() - .map(|p| strip_purl_qualifiers(p).to_string()) - .collect::>() - .into_iter() - .collect(); - - match python_crawler.get_site_packages_paths(&crawler_options).await { - Ok(sp_paths) => { - for sp_path in &sp_paths { - if let Ok(packages) = python_crawler.find_by_purls(sp_path, &base_pypi_purls).await { - for (purl, pkg) in packages { - all_packages.entry(purl).or_insert(pkg.path); - } - } - } - } - Err(e) => { - if !args.silent { - eprintln!("Failed to find Python packages: {e}"); - } - } - } - } - - if all_packages.is_empty() && npm_purls.is_empty() && pypi_purls.is_empty() { + if all_packages.is_empty() && !has_any_purls { if !args.silent { if args.global || args.global_prefix.is_some() { eprintln!("No global packages found"); @@ -271,20 +212,22 @@ async fn apply_patches_inner( let mut results: Vec = Vec::new(); let mut has_errors = false; - // Group pypi PURLs by base + // Group pypi PURLs by base (for variant matching with qualifiers) let mut pypi_qualified_groups: HashMap> = HashMap::new(); - for purl in &pypi_purls { - let base = strip_purl_qualifiers(purl).to_string(); - pypi_qualified_groups - .entry(base) - .or_default() - .push(purl.clone()); + if let Some(pypi_purls) = partitioned.get(&Ecosystem::Pypi) { + for purl in pypi_purls { + let base = strip_purl_qualifiers(purl).to_string(); + pypi_qualified_groups + .entry(base) + .or_default() + .push(purl.clone()); + } } let mut applied_base_purls: HashSet = HashSet::new(); for (purl, pkg_path) in &all_packages { - if is_pypi_purl(purl) { + if Ecosystem::from_purl(purl) == Some(Ecosystem::Pypi) { let base_purl = strip_purl_qualifiers(purl).to_string(); if applied_base_purls.contains(&base_purl) { continue; diff --git a/crates/socket-patch-cli/src/commands/get.rs b/crates/socket-patch-cli/src/commands/get.rs index 30987be..fb75ed4 100644 --- a/crates/socket-patch-cli/src/commands/get.rs +++ b/crates/socket-patch-cli/src/commands/get.rs @@ -2,7 +2,7 @@ use clap::Args; use regex::Regex; use socket_patch_core::api::client::get_api_client_from_env; use socket_patch_core::api::types::{PatchSearchResult, SearchResponse}; -use socket_patch_core::crawlers::{CrawlerOptions, NpmCrawler, PythonCrawler}; +use socket_patch_core::crawlers::CrawlerOptions; use socket_patch_core::manifest::operations::{read_manifest, write_manifest}; use socket_patch_core::manifest::schema::{ PatchFileInfo, PatchManifest, PatchRecord, VulnerabilityInfo, @@ -13,6 +13,8 @@ use std::collections::HashMap; use std::io::{self, Write}; use std::path::PathBuf; +use crate::ecosystem_dispatch::crawl_all_ecosystems; + #[derive(Args)] pub struct GetArgs { /// Patch identifier (UUID, CVE ID, GHSA ID, PURL, or package name) @@ -236,18 +238,17 @@ pub async fn run(args: GetArgs) -> i32 { global_prefix: args.global_prefix.clone(), batch_size: 100, }; - let npm_crawler = NpmCrawler; - let python_crawler = PythonCrawler; - let npm_packages = npm_crawler.crawl_all(&crawler_options).await; - let python_packages = python_crawler.crawl_all(&crawler_options).await; - let mut all_packages = npm_packages; - all_packages.extend(python_packages); + let (all_packages, _) = crawl_all_ecosystems(&crawler_options).await; if all_packages.is_empty() { if args.global { println!("No global packages found."); } else { - println!("No packages found. Run npm/yarn/pnpm/pip install first."); + #[cfg(feature = "cargo")] + let install_cmds = "npm/yarn/pnpm/pip/cargo"; + #[cfg(not(feature = "cargo"))] + let install_cmds = "npm/yarn/pnpm/pip"; + println!("No packages found. Run {install_cmds} install first."); } return 0; } diff --git a/crates/socket-patch-cli/src/commands/rollback.rs b/crates/socket-patch-cli/src/commands/rollback.rs index c09f161..4f1bb62 100644 --- a/crates/socket-patch-cli/src/commands/rollback.rs +++ b/crates/socket-patch-cli/src/commands/rollback.rs @@ -4,16 +4,16 @@ use socket_patch_core::api::blob_fetcher::{ }; use socket_patch_core::api::client::get_api_client_from_env; use socket_patch_core::constants::DEFAULT_PATCH_MANIFEST_PATH; -use socket_patch_core::crawlers::{CrawlerOptions, NpmCrawler, PythonCrawler}; +use socket_patch_core::crawlers::CrawlerOptions; use socket_patch_core::manifest::operations::read_manifest; use socket_patch_core::manifest::schema::{PatchManifest, PatchRecord}; use socket_patch_core::patch::rollback::{rollback_package_patch, RollbackResult}; -use socket_patch_core::utils::global_packages::get_global_prefix; -use socket_patch_core::utils::purl::{is_pypi_purl, strip_purl_qualifiers}; use socket_patch_core::utils::telemetry::{track_patch_rolled_back, track_patch_rollback_failed}; -use std::collections::{HashMap, HashSet}; +use std::collections::HashSet; use std::path::{Path, PathBuf}; +use crate::ecosystem_dispatch::{find_packages_for_rollback, partition_purls}; + #[derive(Args)] pub struct RollbackArgs { /// Package PURL or patch UUID to rollback. Omit to rollback all patches. @@ -335,17 +335,8 @@ async fn rollback_patches_inner( // Partition PURLs by ecosystem let rollback_purls: Vec = patches_to_rollback.iter().map(|p| p.purl.clone()).collect(); - let mut npm_purls: Vec = rollback_purls.iter().filter(|p| !is_pypi_purl(p)).cloned().collect(); - let mut pypi_purls: Vec = rollback_purls.iter().filter(|p| is_pypi_purl(p)).cloned().collect(); - - if let Some(ref ecosystems) = args.ecosystems { - if !ecosystems.iter().any(|e| e == "npm") { - npm_purls.clear(); - } - if !ecosystems.iter().any(|e| e == "pypi") { - pypi_purls.clear(); - } - } + let partitioned = + partition_purls(&rollback_purls, args.ecosystems.as_deref()); let crawler_options = CrawlerOptions { cwd: args.cwd.clone(), @@ -354,70 +345,8 @@ async fn rollback_patches_inner( batch_size: 100, }; - let mut all_packages: HashMap = HashMap::new(); - - // Find npm packages - if !npm_purls.is_empty() { - if args.global || args.global_prefix.is_some() { - match get_global_prefix(args.global_prefix.as_ref().map(|p| p.to_str().unwrap_or(""))) { - Ok(prefix) => { - if !args.silent { - println!("Using global npm packages at: {prefix}"); - } - let npm_crawler = NpmCrawler; - if let Ok(packages) = npm_crawler.find_by_purls(Path::new(&prefix), &npm_purls).await { - for (purl, pkg) in packages { - all_packages.entry(purl).or_insert(pkg.path); - } - } - } - Err(e) => { - if !args.silent { - eprintln!("Failed to find global npm packages: {e}"); - } - return Ok((false, Vec::new())); - } - } - } else { - let npm_crawler = NpmCrawler; - if let Ok(nm_paths) = npm_crawler.get_node_modules_paths(&crawler_options).await { - for nm_path in &nm_paths { - if let Ok(packages) = npm_crawler.find_by_purls(nm_path, &npm_purls).await { - for (purl, pkg) in packages { - all_packages.entry(purl).or_insert(pkg.path); - } - } - } - } - } - } - - // Find Python packages - if !pypi_purls.is_empty() { - let python_crawler = PythonCrawler; - let base_pypi_purls: Vec = pypi_purls - .iter() - .map(|p| strip_purl_qualifiers(p).to_string()) - .collect::>() - .into_iter() - .collect(); - - if let Ok(sp_paths) = python_crawler.get_site_packages_paths(&crawler_options).await { - for sp_path in &sp_paths { - if let Ok(packages) = python_crawler.find_by_purls(sp_path, &base_pypi_purls).await { - for (base_purl, pkg) in packages { - for qualified_purl in &pypi_purls { - if strip_purl_qualifiers(qualified_purl) == base_purl - && !all_packages.contains_key(qualified_purl) - { - all_packages.insert(qualified_purl.clone(), pkg.path.clone()); - } - } - } - } - } - } - } + let all_packages = + find_packages_for_rollback(&partitioned, &crawler_options, args.silent).await; if all_packages.is_empty() { if !args.silent { diff --git a/crates/socket-patch-cli/src/commands/scan.rs b/crates/socket-patch-cli/src/commands/scan.rs index 81478aa..cf7f463 100644 --- a/crates/socket-patch-cli/src/commands/scan.rs +++ b/crates/socket-patch-cli/src/commands/scan.rs @@ -1,10 +1,12 @@ use clap::Args; use socket_patch_core::api::client::get_api_client_from_env; use socket_patch_core::api::types::BatchPackagePatches; -use socket_patch_core::crawlers::{CrawlerOptions, NpmCrawler, PythonCrawler}; +use socket_patch_core::crawlers::{CrawlerOptions, Ecosystem}; use std::collections::HashSet; use std::path::PathBuf; +use crate::ecosystem_dispatch::crawl_all_ecosystems; + const DEFAULT_BATCH_SIZE: usize = 100; #[derive(Args)] @@ -82,23 +84,10 @@ pub async fn run(args: ScanArgs) -> i32 { } // Crawl packages - let npm_crawler = NpmCrawler; - let python_crawler = PythonCrawler; - - let npm_packages = npm_crawler.crawl_all(&crawler_options).await; - let python_packages = python_crawler.crawl_all(&crawler_options).await; - - let mut all_purls: Vec = Vec::new(); - for pkg in &npm_packages { - all_purls.push(pkg.purl.clone()); - } - for pkg in &python_packages { - all_purls.push(pkg.purl.clone()); - } + let (all_crawled, eco_counts) = crawl_all_ecosystems(&crawler_options).await; + let all_purls: Vec = all_crawled.iter().map(|p| p.purl.clone()).collect(); let package_count = all_purls.len(); - let npm_count = npm_packages.len(); - let python_count = python_packages.len(); if package_count == 0 { if !args.json { @@ -121,18 +110,22 @@ pub async fn run(args: ScanArgs) -> i32 { } else if args.global || args.global_prefix.is_some() { println!("No global packages found."); } else { - println!("No packages found. Run npm/yarn/pnpm/pip install first."); + #[cfg(feature = "cargo")] + let install_cmds = "npm/yarn/pnpm/pip/cargo"; + #[cfg(not(feature = "cargo"))] + let install_cmds = "npm/yarn/pnpm/pip"; + println!("No packages found. Run {install_cmds} install first."); } return 0; } // Build ecosystem summary let mut eco_parts = Vec::new(); - if npm_count > 0 { - eco_parts.push(format!("{npm_count} npm")); - } - if python_count > 0 { - eco_parts.push(format!("{python_count} python")); + for eco in Ecosystem::all() { + let count = eco_counts.get(eco).copied().unwrap_or(0); + if count > 0 { + eco_parts.push(format!("{count} {}", eco.display_name())); + } } let eco_summary = if eco_parts.is_empty() { String::new() diff --git a/crates/socket-patch-cli/src/ecosystem_dispatch.rs b/crates/socket-patch-cli/src/ecosystem_dispatch.rs new file mode 100644 index 0000000..df7f438 --- /dev/null +++ b/crates/socket-patch-cli/src/ecosystem_dispatch.rs @@ -0,0 +1,266 @@ +use socket_patch_core::crawlers::{ + CrawledPackage, CrawlerOptions, Ecosystem, NpmCrawler, PythonCrawler, +}; +use socket_patch_core::utils::purl::strip_purl_qualifiers; +use std::collections::{HashMap, HashSet}; +use std::path::PathBuf; + +#[cfg(feature = "cargo")] +use socket_patch_core::crawlers::CargoCrawler; + +/// Partition PURLs by ecosystem, filtering by the `--ecosystems` flag if set. +pub fn partition_purls( + purls: &[String], + allowed_ecosystems: Option<&[String]>, +) -> HashMap> { + let mut map: HashMap> = HashMap::new(); + + for purl in purls { + if let Some(eco) = Ecosystem::from_purl(purl) { + if let Some(allowed) = allowed_ecosystems { + if !allowed.iter().any(|a| a == eco.cli_name()) { + continue; + } + } + map.entry(eco).or_default().push(purl.clone()); + } + } + + map +} + +/// For each ecosystem in the partitioned map, create the crawler, discover +/// source paths, and look up the given PURLs. Returns a unified +/// `purl -> path` map. +pub async fn find_packages_for_purls( + partitioned: &HashMap>, + options: &CrawlerOptions, + silent: bool, +) -> HashMap { + let mut all_packages: HashMap = HashMap::new(); + + // npm + if let Some(npm_purls) = partitioned.get(&Ecosystem::Npm) { + if !npm_purls.is_empty() { + let npm_crawler = NpmCrawler; + match npm_crawler.get_node_modules_paths(options).await { + Ok(nm_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = nm_paths.first() { + println!("Using global npm packages at: {}", first.display()); + } + } + for nm_path in &nm_paths { + if let Ok(packages) = npm_crawler.find_by_purls(nm_path, npm_purls).await { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find npm packages: {e}"); + } + } + } + } + } + + // pypi — deduplicate by base PURL (stripping qualifiers) + if let Some(pypi_purls) = partitioned.get(&Ecosystem::Pypi) { + if !pypi_purls.is_empty() { + let python_crawler = PythonCrawler; + let base_pypi_purls: Vec = pypi_purls + .iter() + .map(|p| strip_purl_qualifiers(p).to_string()) + .collect::>() + .into_iter() + .collect(); + + match python_crawler.get_site_packages_paths(options).await { + Ok(sp_paths) => { + for sp_path in &sp_paths { + if let Ok(packages) = + python_crawler.find_by_purls(sp_path, &base_pypi_purls).await + { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find Python packages: {e}"); + } + } + } + } + } + + // cargo + #[cfg(feature = "cargo")] + if let Some(cargo_purls) = partitioned.get(&Ecosystem::Cargo) { + if !cargo_purls.is_empty() { + let cargo_crawler = CargoCrawler; + match cargo_crawler.get_crate_source_paths(options).await { + Ok(src_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = src_paths.first() { + println!("Using cargo crate sources at: {}", first.display()); + } + } + for src_path in &src_paths { + if let Ok(packages) = + cargo_crawler.find_by_purls(src_path, cargo_purls).await + { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find Cargo crates: {e}"); + } + } + } + } + } + + all_packages +} + +/// Crawl all enabled ecosystems and return all packages plus per-ecosystem counts. +pub async fn crawl_all_ecosystems( + options: &CrawlerOptions, +) -> (Vec, HashMap) { + let mut all_packages = Vec::new(); + let mut counts: HashMap = HashMap::new(); + + let npm_crawler = NpmCrawler; + let npm_packages = npm_crawler.crawl_all(options).await; + counts.insert(Ecosystem::Npm, npm_packages.len()); + all_packages.extend(npm_packages); + + let python_crawler = PythonCrawler; + let python_packages = python_crawler.crawl_all(options).await; + counts.insert(Ecosystem::Pypi, python_packages.len()); + all_packages.extend(python_packages); + + #[cfg(feature = "cargo")] + { + let cargo_crawler = CargoCrawler; + let cargo_packages = cargo_crawler.crawl_all(options).await; + counts.insert(Ecosystem::Cargo, cargo_packages.len()); + all_packages.extend(cargo_packages); + } + + (all_packages, counts) +} + +/// Variant of `find_packages_for_purls` for rollback, which needs to remap +/// pypi qualified PURLs (with `?artifact_id=...`) to the base PURL found +/// by the crawler. +pub async fn find_packages_for_rollback( + partitioned: &HashMap>, + options: &CrawlerOptions, + silent: bool, +) -> HashMap { + let mut all_packages: HashMap = HashMap::new(); + + // npm + if let Some(npm_purls) = partitioned.get(&Ecosystem::Npm) { + if !npm_purls.is_empty() { + let npm_crawler = NpmCrawler; + match npm_crawler.get_node_modules_paths(options).await { + Ok(nm_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = nm_paths.first() { + println!("Using global npm packages at: {}", first.display()); + } + } + for nm_path in &nm_paths { + if let Ok(packages) = npm_crawler.find_by_purls(nm_path, npm_purls).await { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find npm packages: {e}"); + } + } + } + } + } + + // pypi — remap qualified PURLs to found base PURLs + if let Some(pypi_purls) = partitioned.get(&Ecosystem::Pypi) { + if !pypi_purls.is_empty() { + let python_crawler = PythonCrawler; + let base_pypi_purls: Vec = pypi_purls + .iter() + .map(|p| strip_purl_qualifiers(p).to_string()) + .collect::>() + .into_iter() + .collect(); + + if let Ok(sp_paths) = python_crawler.get_site_packages_paths(options).await { + for sp_path in &sp_paths { + if let Ok(packages) = + python_crawler.find_by_purls(sp_path, &base_pypi_purls).await + { + for (base_purl, pkg) in packages { + for qualified_purl in pypi_purls { + if strip_purl_qualifiers(qualified_purl) == base_purl + && !all_packages.contains_key(qualified_purl) + { + all_packages + .insert(qualified_purl.clone(), pkg.path.clone()); + } + } + } + } + } + } + } + } + + // cargo + #[cfg(feature = "cargo")] + if let Some(cargo_purls) = partitioned.get(&Ecosystem::Cargo) { + if !cargo_purls.is_empty() { + let cargo_crawler = CargoCrawler; + match cargo_crawler.get_crate_source_paths(options).await { + Ok(src_paths) => { + if (options.global || options.global_prefix.is_some()) && !silent { + if let Some(first) = src_paths.first() { + println!("Using cargo crate sources at: {}", first.display()); + } + } + for src_path in &src_paths { + if let Ok(packages) = + cargo_crawler.find_by_purls(src_path, cargo_purls).await + { + for (purl, pkg) in packages { + all_packages.entry(purl).or_insert(pkg.path); + } + } + } + } + Err(e) => { + if !silent { + eprintln!("Failed to find Cargo crates: {e}"); + } + } + } + } + } + + all_packages +} diff --git a/crates/socket-patch-cli/src/main.rs b/crates/socket-patch-cli/src/main.rs index 433ebd2..b7a9ee7 100644 --- a/crates/socket-patch-cli/src/main.rs +++ b/crates/socket-patch-cli/src/main.rs @@ -1,4 +1,5 @@ mod commands; +mod ecosystem_dispatch; use clap::{Parser, Subcommand}; diff --git a/crates/socket-patch-cli/tests/e2e_cargo.rs b/crates/socket-patch-cli/tests/e2e_cargo.rs new file mode 100644 index 0000000..c4be5bb --- /dev/null +++ b/crates/socket-patch-cli/tests/e2e_cargo.rs @@ -0,0 +1,112 @@ +#![cfg(feature = "cargo")] +//! End-to-end tests for the Cargo/Rust crate patching lifecycle. +//! +//! These tests exercise crawling against a temporary directory with a fake +//! Cargo registry layout. They do **not** require network access or a real +//! Cargo installation. +//! +//! # Running +//! ```sh +//! cargo test -p socket-patch-cli --features cargo --test e2e_cargo +//! ``` + +use std::path::PathBuf; +use std::process::{Command, Output}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +fn binary() -> PathBuf { + env!("CARGO_BIN_EXE_socket-patch").into() +} + +fn run(args: &[&str], cwd: &std::path::Path) -> Output { + Command::new(binary()) + .args(args) + .current_dir(cwd) + .env("CARGO_HOME", cwd.join(".cargo")) + .output() + .expect("Failed to run socket-patch binary") +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +/// Verify that `socket-patch scan` discovers crates in a fake registry layout. +#[test] +fn scan_discovers_fake_registry_crates() { + let dir = tempfile::tempdir().unwrap(); + + // Set up a fake CARGO_HOME/registry/src/index.crates.io-xxx/ structure + let index_dir = dir + .path() + .join(".cargo") + .join("registry") + .join("src") + .join("index.crates.io-test"); + + // Create serde-1.0.200 + let serde_dir = index_dir.join("serde-1.0.200"); + std::fs::create_dir_all(&serde_dir).unwrap(); + std::fs::write( + serde_dir.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .unwrap(); + + // Create tokio-1.38.0 + let tokio_dir = index_dir.join("tokio-1.38.0"); + std::fs::create_dir_all(&tokio_dir).unwrap(); + std::fs::write( + tokio_dir.join("Cargo.toml"), + "[package]\nname = \"tokio\"\nversion = \"1.38.0\"\n", + ) + .unwrap(); + + // Run scan (will fail to connect to API, but we just check discovery) + let output = run(&["scan", "--cwd", dir.path().to_str().unwrap()], dir.path()); + let stderr = String::from_utf8_lossy(&output.stderr); + let stdout = String::from_utf8_lossy(&output.stdout); + let combined = format!("{stdout}{stderr}"); + + // Should discover the crates (output mentions "Found X packages") + assert!( + combined.contains("Found") || combined.contains("packages"), + "Expected scan to discover crate packages, got:\n{combined}" + ); +} + +/// Verify that `socket-patch scan` discovers crates in a vendor layout. +#[test] +fn scan_discovers_vendor_crates() { + let dir = tempfile::tempdir().unwrap(); + + // Set up vendor directory + let vendor_dir = dir.path().join("vendor"); + + let serde_dir = vendor_dir.join("serde"); + std::fs::create_dir_all(&serde_dir).unwrap(); + std::fs::write( + serde_dir.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .unwrap(); + + // Run scan with JSON output to avoid API calls + let output = run( + &["scan", "--json", "--cwd", dir.path().to_str().unwrap()], + dir.path(), + ); + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + + // JSON output should show scannedPackages >= 1 (the vendor crate) + // or at minimum the scan should report finding packages + let combined = format!("{stdout}{stderr}"); + assert!( + combined.contains("scannedPackages") || combined.contains("Found"), + "Expected scan output, got:\n{combined}" + ); +} diff --git a/crates/socket-patch-core/Cargo.toml b/crates/socket-patch-core/Cargo.toml index 7930beb..7e1e9ac 100644 --- a/crates/socket-patch-core/Cargo.toml +++ b/crates/socket-patch-core/Cargo.toml @@ -19,6 +19,10 @@ uuid = { workspace = true } regex = { workspace = true } once_cell = { workspace = true } +[features] +default = [] +cargo = [] + [dev-dependencies] tempfile = { workspace = true } tokio = { version = "1", features = ["full", "test-util"] } diff --git a/crates/socket-patch-core/src/crawlers/cargo_crawler.rs b/crates/socket-patch-core/src/crawlers/cargo_crawler.rs new file mode 100644 index 0000000..17ccb5c --- /dev/null +++ b/crates/socket-patch-core/src/crawlers/cargo_crawler.rs @@ -0,0 +1,641 @@ +#![cfg(feature = "cargo")] + +use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; + +use super::types::{CrawledPackage, CrawlerOptions}; + +// --------------------------------------------------------------------------- +// Cargo.toml minimal parser +// --------------------------------------------------------------------------- + +/// Parse `name` and `version` from a `Cargo.toml` `[package]` section. +/// +/// Uses a simple line-based parser — no TOML crate dependency. +/// Handles `name = "..."` and `version = "..."` within the `[package]` table. +/// Returns `None` if `version.workspace = true` or fields are missing. +pub fn parse_cargo_toml_name_version(content: &str) -> Option<(String, String)> { + let mut in_package = false; + let mut name: Option = None; + let mut version: Option = None; + + for line in content.lines() { + let trimmed = line.trim(); + + // Skip comments and empty lines + if trimmed.starts_with('#') || trimmed.is_empty() { + continue; + } + + // Track table headers + if trimmed.starts_with('[') { + if trimmed == "[package]" { + in_package = true; + } else { + // We left the [package] section + if in_package { + break; + } + } + continue; + } + + if !in_package { + continue; + } + + if let Some(val) = extract_string_value(trimmed, "name") { + name = Some(val); + } else if let Some(val) = extract_string_value(trimmed, "version") { + version = Some(val); + } else if trimmed.starts_with("version") && trimmed.contains("workspace") { + // version.workspace = true — cannot determine version from this file + return None; + } + + if name.is_some() && version.is_some() { + break; + } + } + + match (name, version) { + (Some(n), Some(v)) if !n.is_empty() && !v.is_empty() => Some((n, v)), + _ => None, + } +} + +/// Extract a quoted string value from a `key = "value"` line. +fn extract_string_value(line: &str, key: &str) -> Option { + let rest = line.strip_prefix(key)?; + let rest = rest.trim_start(); + let rest = rest.strip_prefix('=')?; + let rest = rest.trim_start(); + let rest = rest.strip_prefix('"')?; + let end = rest.find('"')?; + Some(rest[..end].to_string()) +} + +// --------------------------------------------------------------------------- +// CargoCrawler +// --------------------------------------------------------------------------- + +/// Cargo/Rust ecosystem crawler for discovering crates in the local +/// vendor directory or the Cargo registry cache (`$CARGO_HOME/registry/src/`). +pub struct CargoCrawler; + +impl CargoCrawler { + /// Create a new `CargoCrawler`. + pub fn new() -> Self { + Self + } + + // ------------------------------------------------------------------ + // Public API + // ------------------------------------------------------------------ + + /// Get crate source paths based on options. + /// + /// In local mode, checks `/vendor/` first, then falls back to + /// `$CARGO_HOME/registry/src/` index directories. + /// + /// In global mode, returns `$CARGO_HOME/registry/src/` index directories + /// (or the `--global-prefix` override). + pub async fn get_crate_source_paths( + &self, + options: &CrawlerOptions, + ) -> Result, std::io::Error> { + if options.global || options.global_prefix.is_some() { + if let Some(ref custom) = options.global_prefix { + return Ok(vec![custom.clone()]); + } + return Ok(Self::get_registry_src_paths().await); + } + + // Local mode: check vendor first + let vendor_dir = options.cwd.join("vendor"); + if is_dir(&vendor_dir).await { + return Ok(vec![vendor_dir]); + } + + // Fall back to registry cache + Ok(Self::get_registry_src_paths().await) + } + + /// Crawl all discovered crate source directories and return every + /// package found. + pub async fn crawl_all(&self, options: &CrawlerOptions) -> Vec { + let mut packages = Vec::new(); + let mut seen = HashSet::new(); + + let src_paths = self.get_crate_source_paths(options).await.unwrap_or_default(); + + for src_path in &src_paths { + let found = self.scan_crate_source(src_path, &mut seen).await; + packages.extend(found); + } + + packages + } + + /// Find specific packages by PURL inside a single crate source directory. + /// + /// Supports two layouts: + /// - **Registry**: `-/Cargo.toml` + /// - **Vendor**: `/Cargo.toml` (version verified from file contents) + pub async fn find_by_purls( + &self, + src_path: &Path, + purls: &[String], + ) -> Result, std::io::Error> { + let mut result: HashMap = HashMap::new(); + + for purl in purls { + if let Some((name, version)) = crate::utils::purl::parse_cargo_purl(purl) { + // Try registry layout: -/ + let registry_dir = src_path.join(format!("{name}-{version}")); + if self + .verify_crate_at_path(®istry_dir, name, version) + .await + { + result.insert( + purl.clone(), + CrawledPackage { + name: name.to_string(), + version: version.to_string(), + namespace: None, + purl: purl.clone(), + path: registry_dir, + }, + ); + continue; + } + + // Try vendor layout: / + let vendor_dir = src_path.join(name); + if self + .verify_crate_at_path(&vendor_dir, name, version) + .await + { + result.insert( + purl.clone(), + CrawledPackage { + name: name.to_string(), + version: version.to_string(), + namespace: None, + purl: purl.clone(), + path: vendor_dir, + }, + ); + } + } + } + + Ok(result) + } + + // ------------------------------------------------------------------ + // Private helpers + // ------------------------------------------------------------------ + + /// List subdirectories of `$CARGO_HOME/registry/src/`. + /// + /// Each subdirectory corresponds to a registry index + /// (e.g. `index.crates.io-6f17d22bba15001f/`). + async fn get_registry_src_paths() -> Vec { + let cargo_home = Self::cargo_home(); + let registry_src = cargo_home.join("registry").join("src"); + + let mut paths = Vec::new(); + + let mut entries = match tokio::fs::read_dir(®istry_src).await { + Ok(rd) => rd, + Err(_) => return paths, + }; + + while let Ok(Some(entry)) = entries.next_entry().await { + let ft = match entry.file_type().await { + Ok(ft) => ft, + Err(_) => continue, + }; + if ft.is_dir() { + paths.push(registry_src.join(entry.file_name())); + } + } + + paths + } + + /// Scan a crate source directory (either a registry index directory or + /// a vendor directory) and return all valid crate packages found. + async fn scan_crate_source( + &self, + src_path: &Path, + seen: &mut HashSet, + ) -> Vec { + let mut results = Vec::new(); + + let mut entries = match tokio::fs::read_dir(src_path).await { + Ok(rd) => rd, + Err(_) => return results, + }; + + let mut entry_list = Vec::new(); + while let Ok(Some(entry)) = entries.next_entry().await { + entry_list.push(entry); + } + + for entry in entry_list { + let ft = match entry.file_type().await { + Ok(ft) => ft, + Err(_) => continue, + }; + if !ft.is_dir() { + continue; + } + + let dir_name = entry.file_name(); + let dir_name_str = dir_name.to_string_lossy(); + + // Skip hidden directories + if dir_name_str.starts_with('.') { + continue; + } + + let crate_path = src_path.join(&*dir_name_str); + if let Some(pkg) = + self.read_crate_cargo_toml(&crate_path, &dir_name_str, seen).await + { + results.push(pkg); + } + } + + results + } + + /// Read `Cargo.toml` from a crate directory, returning a `CrawledPackage` + /// if valid. Falls back to parsing name+version from the directory name + /// when the Cargo.toml has `version.workspace = true`. + async fn read_crate_cargo_toml( + &self, + crate_path: &Path, + dir_name: &str, + seen: &mut HashSet, + ) -> Option { + let cargo_toml_path = crate_path.join("Cargo.toml"); + let content = tokio::fs::read_to_string(&cargo_toml_path).await.ok()?; + + let (name, version) = match parse_cargo_toml_name_version(&content) { + Some(nv) => nv, + None => { + // Fallback: parse directory name as - + Self::parse_dir_name_version(dir_name)? + } + }; + + let purl = crate::utils::purl::build_cargo_purl(&name, &version); + + if seen.contains(&purl) { + return None; + } + seen.insert(purl.clone()); + + Some(CrawledPackage { + name, + version, + namespace: None, + purl, + path: crate_path.to_path_buf(), + }) + } + + /// Verify that a crate directory contains a Cargo.toml with the expected + /// name and version. + async fn verify_crate_at_path(&self, path: &Path, name: &str, version: &str) -> bool { + let cargo_toml_path = path.join("Cargo.toml"); + let content = match tokio::fs::read_to_string(&cargo_toml_path).await { + Ok(c) => c, + Err(_) => return false, + }; + + match parse_cargo_toml_name_version(&content) { + Some((n, v)) => n == name && v == version, + None => { + // Fallback: check directory name + let dir_name = path + .file_name() + .map(|n| n.to_string_lossy().to_string()) + .unwrap_or_default(); + if let Some((parsed_name, parsed_version)) = + Self::parse_dir_name_version(&dir_name) + { + parsed_name == name && parsed_version == version + } else { + false + } + } + } + } + + /// Parse a registry directory name into (name, version). + /// + /// Registry directories follow the pattern `-`, + /// where the version is the last `-`-separated component that starts with + /// a digit (handles crate names with hyphens like `serde-json`). + fn parse_dir_name_version(dir_name: &str) -> Option<(String, String)> { + // Find the last '-' followed by a digit + let mut split_idx = None; + for (i, _) in dir_name.match_indices('-') { + if dir_name[i + 1..].starts_with(|c: char| c.is_ascii_digit()) { + split_idx = Some(i); + } + } + let idx = split_idx?; + let name = &dir_name[..idx]; + let version = &dir_name[idx + 1..]; + if name.is_empty() || version.is_empty() { + return None; + } + Some((name.to_string(), version.to_string())) + } + + /// Get `CARGO_HOME`, defaulting to `$HOME/.cargo`. + fn cargo_home() -> PathBuf { + if let Ok(cargo_home) = std::env::var("CARGO_HOME") { + return PathBuf::from(cargo_home); + } + let home = std::env::var("HOME") + .or_else(|_| std::env::var("USERPROFILE")) + .unwrap_or_else(|_| "~".to_string()); + PathBuf::from(home).join(".cargo") + } +} + +impl Default for CargoCrawler { + fn default() -> Self { + Self::new() + } +} + +/// Check whether a path is a directory. +async fn is_dir(path: &Path) -> bool { + tokio::fs::metadata(path) + .await + .map(|m| m.is_dir()) + .unwrap_or(false) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_cargo_toml_basic() { + let content = r#" +[package] +name = "serde" +version = "1.0.200" +edition = "2021" +"#; + let (name, version) = parse_cargo_toml_name_version(content).unwrap(); + assert_eq!(name, "serde"); + assert_eq!(version, "1.0.200"); + } + + #[test] + fn test_parse_cargo_toml_with_comments() { + let content = r#" +# This is a comment +[package] +name = "tokio" # inline comment ignored since we stop at first " +version = "1.38.0" +"#; + let (name, version) = parse_cargo_toml_name_version(content).unwrap(); + assert_eq!(name, "tokio"); + assert_eq!(version, "1.38.0"); + } + + #[test] + fn test_parse_cargo_toml_workspace_version() { + let content = r#" +[package] +name = "my-crate" +version.workspace = true +"#; + assert!(parse_cargo_toml_name_version(content).is_none()); + } + + #[test] + fn test_parse_cargo_toml_missing_fields() { + let content = r#" +[package] +name = "incomplete" +"#; + assert!(parse_cargo_toml_name_version(content).is_none()); + } + + #[test] + fn test_parse_cargo_toml_no_package_section() { + let content = r#" +[dependencies] +serde = "1.0" +"#; + assert!(parse_cargo_toml_name_version(content).is_none()); + } + + #[test] + fn test_parse_cargo_toml_stops_at_next_section() { + let content = r#" +[package] +name = "foo" + +[dependencies] +version = "fake" +"#; + // Should not find version since it's under [dependencies] + assert!(parse_cargo_toml_name_version(content).is_none()); + } + + #[test] + fn test_parse_dir_name_version() { + assert_eq!( + CargoCrawler::parse_dir_name_version("serde-1.0.200"), + Some(("serde".to_string(), "1.0.200".to_string())) + ); + assert_eq!( + CargoCrawler::parse_dir_name_version("serde-json-1.0.120"), + Some(("serde-json".to_string(), "1.0.120".to_string())) + ); + assert_eq!( + CargoCrawler::parse_dir_name_version("tokio-1.38.0"), + Some(("tokio".to_string(), "1.38.0".to_string())) + ); + assert!(CargoCrawler::parse_dir_name_version("no-version-here").is_none()); + assert!(CargoCrawler::parse_dir_name_version("noversion").is_none()); + } + + #[tokio::test] + async fn test_find_by_purls_registry_layout() { + let dir = tempfile::tempdir().unwrap(); + let serde_dir = dir.path().join("serde-1.0.200"); + tokio::fs::create_dir_all(&serde_dir).await.unwrap(); + tokio::fs::write( + serde_dir.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .await + .unwrap(); + + let crawler = CargoCrawler::new(); + let purls = vec![ + "pkg:cargo/serde@1.0.200".to_string(), + "pkg:cargo/tokio@1.38.0".to_string(), + ]; + let result = crawler.find_by_purls(dir.path(), &purls).await.unwrap(); + + assert_eq!(result.len(), 1); + assert!(result.contains_key("pkg:cargo/serde@1.0.200")); + assert!(!result.contains_key("pkg:cargo/tokio@1.38.0")); + } + + #[tokio::test] + async fn test_find_by_purls_vendor_layout() { + let dir = tempfile::tempdir().unwrap(); + let serde_dir = dir.path().join("serde"); + tokio::fs::create_dir_all(&serde_dir).await.unwrap(); + tokio::fs::write( + serde_dir.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .await + .unwrap(); + + let crawler = CargoCrawler::new(); + let purls = vec!["pkg:cargo/serde@1.0.200".to_string()]; + let result = crawler.find_by_purls(dir.path(), &purls).await.unwrap(); + + assert_eq!(result.len(), 1); + assert!(result.contains_key("pkg:cargo/serde@1.0.200")); + } + + #[tokio::test] + async fn test_crawl_all_tempdir() { + let dir = tempfile::tempdir().unwrap(); + + // Create fake crate directories + let serde_dir = dir.path().join("serde-1.0.200"); + tokio::fs::create_dir_all(&serde_dir).await.unwrap(); + tokio::fs::write( + serde_dir.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .await + .unwrap(); + + let tokio_dir = dir.path().join("tokio-1.38.0"); + tokio::fs::create_dir_all(&tokio_dir).await.unwrap(); + tokio::fs::write( + tokio_dir.join("Cargo.toml"), + "[package]\nname = \"tokio\"\nversion = \"1.38.0\"\n", + ) + .await + .unwrap(); + + let crawler = CargoCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 2); + + let purls: HashSet<_> = packages.iter().map(|p| p.purl.as_str()).collect(); + assert!(purls.contains("pkg:cargo/serde@1.0.200")); + assert!(purls.contains("pkg:cargo/tokio@1.38.0")); + } + + #[tokio::test] + async fn test_crawl_all_deduplication() { + let dir = tempfile::tempdir().unwrap(); + + // Create two directories that would resolve to the same PURL + let dir1 = dir.path().join("serde-1.0.200"); + tokio::fs::create_dir_all(&dir1).await.unwrap(); + tokio::fs::write( + dir1.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .await + .unwrap(); + + // This would be found if we scan the parent twice + let crawler = CargoCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 1); + assert_eq!(packages[0].purl, "pkg:cargo/serde@1.0.200"); + } + + #[tokio::test] + async fn test_crawl_workspace_version_fallback() { + let dir = tempfile::tempdir().unwrap(); + + // Create a crate with workspace version — should fall back to dir name parsing + let crate_dir = dir.path().join("my-crate-0.5.0"); + tokio::fs::create_dir_all(&crate_dir).await.unwrap(); + tokio::fs::write( + crate_dir.join("Cargo.toml"), + "[package]\nname = \"my-crate\"\nversion.workspace = true\n", + ) + .await + .unwrap(); + + let crawler = CargoCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: Some(dir.path().to_path_buf()), + batch_size: 100, + }; + + let packages = crawler.crawl_all(&options).await; + assert_eq!(packages.len(), 1); + assert_eq!(packages[0].purl, "pkg:cargo/my-crate@0.5.0"); + } + + #[tokio::test] + async fn test_vendor_layout_via_get_crate_source_paths() { + let dir = tempfile::tempdir().unwrap(); + let vendor = dir.path().join("vendor"); + tokio::fs::create_dir_all(&vendor).await.unwrap(); + + let serde_dir = vendor.join("serde"); + tokio::fs::create_dir_all(&serde_dir).await.unwrap(); + tokio::fs::write( + serde_dir.join("Cargo.toml"), + "[package]\nname = \"serde\"\nversion = \"1.0.200\"\n", + ) + .await + .unwrap(); + + let crawler = CargoCrawler::new(); + let options = CrawlerOptions { + cwd: dir.path().to_path_buf(), + global: false, + global_prefix: None, + batch_size: 100, + }; + + let paths = crawler.get_crate_source_paths(&options).await.unwrap(); + assert_eq!(paths.len(), 1); + assert_eq!(paths[0], vendor); + } +} diff --git a/crates/socket-patch-core/src/crawlers/mod.rs b/crates/socket-patch-core/src/crawlers/mod.rs index 8c33de0..d720db7 100644 --- a/crates/socket-patch-core/src/crawlers/mod.rs +++ b/crates/socket-patch-core/src/crawlers/mod.rs @@ -1,7 +1,11 @@ pub mod npm_crawler; pub mod python_crawler; pub mod types; +#[cfg(feature = "cargo")] +pub mod cargo_crawler; pub use npm_crawler::NpmCrawler; pub use python_crawler::PythonCrawler; pub use types::*; +#[cfg(feature = "cargo")] +pub use cargo_crawler::CargoCrawler; diff --git a/crates/socket-patch-core/src/crawlers/types.rs b/crates/socket-patch-core/src/crawlers/types.rs index 44489a7..1eef452 100644 --- a/crates/socket-patch-core/src/crawlers/types.rs +++ b/crates/socket-patch-core/src/crawlers/types.rs @@ -1,5 +1,71 @@ use std::path::PathBuf; +/// Identifies a supported package ecosystem. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum Ecosystem { + Npm, + Pypi, + #[cfg(feature = "cargo")] + Cargo, +} + +impl Ecosystem { + /// All enabled ecosystems. + pub fn all() -> &'static [Ecosystem] { + &[ + Ecosystem::Npm, + Ecosystem::Pypi, + #[cfg(feature = "cargo")] + Ecosystem::Cargo, + ] + } + + /// Match a PURL string to its ecosystem. + pub fn from_purl(purl: &str) -> Option { + #[cfg(feature = "cargo")] + if purl.starts_with("pkg:cargo/") { + return Some(Ecosystem::Cargo); + } + if purl.starts_with("pkg:npm/") { + return Some(Ecosystem::Npm) + } else if purl.starts_with("pkg:pypi/") { + return Some(Ecosystem::Pypi) + } else { + None + } + } + + /// The PURL prefix for this ecosystem (e.g. `"pkg:npm/"`). + pub fn purl_prefix(&self) -> &'static str { + match self { + Ecosystem::Npm => "pkg:npm/", + Ecosystem::Pypi => "pkg:pypi/", + #[cfg(feature = "cargo")] + Ecosystem::Cargo => "pkg:cargo/", + } + } + + /// Name used in the `--ecosystems` CLI flag (e.g. `"npm"`, `"pypi"`, `"cargo"`). + pub fn cli_name(&self) -> &'static str { + match self { + Ecosystem::Npm => "npm", + Ecosystem::Pypi => "pypi", + #[cfg(feature = "cargo")] + Ecosystem::Cargo => "cargo", + } + } + + /// Human-readable name for user-facing messages. + pub fn display_name(&self) -> &'static str { + match self { + Ecosystem::Npm => "npm", + Ecosystem::Pypi => "python", + #[cfg(feature = "cargo")] + Ecosystem::Cargo => "cargo", + } + } +} + /// Represents a package discovered during crawling. #[derive(Debug, Clone)] pub struct CrawledPackage { @@ -20,9 +86,9 @@ pub struct CrawledPackage { pub struct CrawlerOptions { /// Working directory to start from. pub cwd: PathBuf, - /// Use global packages instead of local node_modules. + /// Use global packages instead of local packages. pub global: bool, - /// Custom path to global node_modules (overrides auto-detection). + /// Custom path to global package directory (overrides auto-detection). pub global_prefix: Option, /// Batch size for yielding packages (default: 100). pub batch_size: usize, @@ -38,3 +104,78 @@ impl Default for CrawlerOptions { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_from_purl_npm() { + assert_eq!( + Ecosystem::from_purl("pkg:npm/lodash@4.17.21"), + Some(Ecosystem::Npm) + ); + assert_eq!( + Ecosystem::from_purl("pkg:npm/@types/node@20.0.0"), + Some(Ecosystem::Npm) + ); + } + + #[test] + fn test_from_purl_pypi() { + assert_eq!( + Ecosystem::from_purl("pkg:pypi/requests@2.28.0"), + Some(Ecosystem::Pypi) + ); + } + + #[test] + fn test_from_purl_unknown() { + assert_eq!(Ecosystem::from_purl("pkg:unknown/foo@1.0"), None); + assert_eq!(Ecosystem::from_purl("not-a-purl"), None); + } + + #[cfg(feature = "cargo")] + #[test] + fn test_from_purl_cargo() { + assert_eq!( + Ecosystem::from_purl("pkg:cargo/serde@1.0.200"), + Some(Ecosystem::Cargo) + ); + } + + #[test] + fn test_all_count() { + let all = Ecosystem::all(); + #[cfg(not(feature = "cargo"))] + assert_eq!(all.len(), 2); + #[cfg(feature = "cargo")] + assert_eq!(all.len(), 3); + } + + #[test] + fn test_cli_name() { + assert_eq!(Ecosystem::Npm.cli_name(), "npm"); + assert_eq!(Ecosystem::Pypi.cli_name(), "pypi"); + } + + #[test] + fn test_display_name() { + assert_eq!(Ecosystem::Npm.display_name(), "npm"); + assert_eq!(Ecosystem::Pypi.display_name(), "python"); + } + + #[test] + fn test_purl_prefix() { + assert_eq!(Ecosystem::Npm.purl_prefix(), "pkg:npm/"); + assert_eq!(Ecosystem::Pypi.purl_prefix(), "pkg:pypi/"); + } + + #[cfg(feature = "cargo")] + #[test] + fn test_cargo_properties() { + assert_eq!(Ecosystem::Cargo.cli_name(), "cargo"); + assert_eq!(Ecosystem::Cargo.display_name(), "cargo"); + assert_eq!(Ecosystem::Cargo.purl_prefix(), "pkg:cargo/"); + } +} diff --git a/crates/socket-patch-core/src/utils/purl.rs b/crates/socket-patch-core/src/utils/purl.rs index b073981..b65152a 100644 --- a/crates/socket-patch-core/src/utils/purl.rs +++ b/crates/socket-patch-core/src/utils/purl.rs @@ -64,8 +64,36 @@ pub fn parse_npm_purl(purl: &str) -> Option<(Option<&str>, &str, &str)> { } } +/// Check if a PURL is a Cargo/Rust crate. +#[cfg(feature = "cargo")] +pub fn is_cargo_purl(purl: &str) -> bool { + purl.starts_with("pkg:cargo/") +} + +/// Parse a Cargo PURL to extract name and version. +/// +/// e.g., `"pkg:cargo/serde@1.0.200"` -> `Some(("serde", "1.0.200"))` +#[cfg(feature = "cargo")] +pub fn parse_cargo_purl(purl: &str) -> Option<(&str, &str)> { + let base = strip_purl_qualifiers(purl); + let rest = base.strip_prefix("pkg:cargo/")?; + let at_idx = rest.rfind('@')?; + let name = &rest[..at_idx]; + let version = &rest[at_idx + 1..]; + if name.is_empty() || version.is_empty() { + return None; + } + Some((name, version)) +} + +/// Build a Cargo PURL from components. +#[cfg(feature = "cargo")] +pub fn build_cargo_purl(name: &str, version: &str) -> String { + format!("pkg:cargo/{name}@{version}") +} + /// Parse a PURL into ecosystem, package directory path, and version. -/// Supports both npm and pypi PURLs. +/// Supports npm, pypi, and (with `cargo` feature) cargo PURLs. pub fn parse_purl(purl: &str) -> Option<(&str, String, &str)> { let base = strip_purl_qualifiers(purl); if let Some(rest) = base.strip_prefix("pkg:npm/") { @@ -85,6 +113,16 @@ pub fn parse_purl(purl: &str) -> Option<(&str, String, &str)> { } Some(("pypi", name.to_string(), version)) } else { + #[cfg(feature = "cargo")] + if let Some(rest) = base.strip_prefix("pkg:cargo/") { + let at_idx = rest.rfind('@')?; + let name = &rest[..at_idx]; + let version = &rest[at_idx + 1..]; + if name.is_empty() || version.is_empty() { + return None; + } + return Some(("cargo", name.to_string(), version)); + } None } } @@ -208,4 +246,55 @@ mod tests { "pkg:pypi/requests@2.28.0" ); } + + #[cfg(feature = "cargo")] + #[test] + fn test_is_cargo_purl() { + assert!(is_cargo_purl("pkg:cargo/serde@1.0.200")); + assert!(!is_cargo_purl("pkg:npm/lodash@4.17.21")); + assert!(!is_cargo_purl("pkg:pypi/requests@2.28.0")); + } + + #[cfg(feature = "cargo")] + #[test] + fn test_parse_cargo_purl() { + assert_eq!( + parse_cargo_purl("pkg:cargo/serde@1.0.200"), + Some(("serde", "1.0.200")) + ); + assert_eq!( + parse_cargo_purl("pkg:cargo/serde_json@1.0.120"), + Some(("serde_json", "1.0.120")) + ); + assert_eq!(parse_cargo_purl("pkg:npm/lodash@4.17.21"), None); + assert_eq!(parse_cargo_purl("pkg:cargo/@1.0.0"), None); + assert_eq!(parse_cargo_purl("pkg:cargo/serde@"), None); + } + + #[cfg(feature = "cargo")] + #[test] + fn test_build_cargo_purl() { + assert_eq!( + build_cargo_purl("serde", "1.0.200"), + "pkg:cargo/serde@1.0.200" + ); + } + + #[cfg(feature = "cargo")] + #[test] + fn test_cargo_purl_round_trip() { + let purl = build_cargo_purl("tokio", "1.38.0"); + let (name, version) = parse_cargo_purl(&purl).unwrap(); + assert_eq!(name, "tokio"); + assert_eq!(version, "1.38.0"); + } + + #[cfg(feature = "cargo")] + #[test] + fn test_parse_purl_cargo() { + let (eco, dir, ver) = parse_purl("pkg:cargo/serde@1.0.200").unwrap(); + assert_eq!(eco, "cargo"); + assert_eq!(dir, "serde"); + assert_eq!(ver, "1.0.200"); + } }