diff --git a/v2/crates/wifi-densepose-train/build.rs b/v2/crates/wifi-densepose-train/build.rs
new file mode 100644
index 0000000000..13347dcb33
--- /dev/null
+++ b/v2/crates/wifi-densepose-train/build.rs
@@ -0,0 +1,50 @@
+//! Build script: when binding libtorch from the system PyTorch install
+//! (`LIBTORCH_USE_PYTORCH=1`), embed an rpath to the PyTorch `lib` directory so
+//! the produced binaries and test executables can locate `libtorch_cpu` at
+//! runtime. torch-sys only adds the rpath to its own cc-built static library,
+//! which does not propagate to downstream binaries on macOS/Linux.
+//!
+//! The probe is best-effort: if the directory cannot be determined the build
+//! continues without the rpath and a `cargo:warning` explains why.
+fn main() {
+    println!("cargo:rerun-if-env-changed=LIBTORCH_USE_PYTORCH");
+    if std::env::var_os("LIBTORCH_USE_PYTORCH").is_none() {
+        return;
+    }
+    let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
+    if !matches!(target_os.as_str(), "macos" | "linux") {
+        return;
+    }
+    // Ask the interpreter where the installed `torch` package keeps its
+    // shared libraries.
+    let probe = std::process::Command::new("python3")
+        .args([
+            "-c",
+            "import os, torch; print(os.path.join(os.path.dirname(torch.__file__), 'lib'))",
+        ])
+        .output();
+    let output = match probe {
+        Ok(output) if output.status.success() => output,
+        Ok(output) => {
+            skip_rpath(&format!("python3 torch probe failed ({})", output.status));
+            return;
+        }
+        Err(e) => {
+            skip_rpath(&format!("could not run python3 ({e})"));
+            return;
+        }
+    };
+    let lib_dir = String::from_utf8_lossy(&output.stdout).trim().to_string();
+    if lib_dir.is_empty() {
+        skip_rpath("python3 torch probe printed no path");
+        return;
+    }
+    // Plain rustc-link-arg (not -bins/-tests): it is the only
+    // variant that also reaches the lib unit-test binary.
+    println!("cargo:rustc-link-arg=-Wl,-rpath,{lib_dir}");
+}
+
+/// Report through `cargo:warning` that no rpath is being embedded and why.
+fn skip_rpath(reason: &str) {
+    println!("cargo:warning=LIBTORCH_USE_PYTORCH is set but no rpath was embedded: {reason}");
+}
diff --git a/v2/crates/wifi-densepose-train/src/bin/train.rs b/v2/crates/wifi-densepose-train/src/bin/train.rs
index 7126d24f2f..52c779a277 100644
--- a/v2/crates/wifi-densepose-train/src/bin/train.rs
+++ b/v2/crates/wifi-densepose-train/src/bin/train.rs
@@ -55,6 +55,12 @@ struct Args {
     #[arg(long, value_name = "DIR")]
     data_dir: Option,
 
+    /// Optional MM-Fi directory used as the REAL validation set (e.g. a
+    /// held-out subject/session for honest cross-subject PCK). Without it,
+    /// a small synthetic set is used for pipeline verification only.
+ #[arg(long, value_name = "DIR")] + val_dir: Option, + /// Override the checkpoint output directory from the config. #[arg(long, value_name = "DIR")] checkpoint_dir: Option, @@ -76,6 +82,21 @@ struct Args { /// Log level: trace, debug, info, warn, error. #[arg(long, default_value = "info")] log_level: String, + + /// Evaluation-only mode: load `--checkpoint` and run the validation loop + /// on `--val-dir` (no training). Prints PCK@0.2 / OKS and exits. + #[arg(long, default_value_t = false)] + eval_only: bool, + + /// Path to a `.pt` checkpoint (tch VarStore format) to load in + /// `--eval-only` mode. + #[arg(long, value_name = "FILE")] + checkpoint: Option, + + /// In `--eval-only` mode, dump per-sample predicted + GT keypoints as + /// JSONL to this path (dataset order, one line per window). + #[arg(long, value_name = "FILE")] + dump_preds: Option, } // --------------------------------------------------------------------------- @@ -137,6 +158,47 @@ fn main() { log_config_summary(&config); + // ------------------------------------------------------------------ + // Eval-only mode: load checkpoint, run validation loop, exit + // ------------------------------------------------------------------ + + if args.eval_only { + let Some(ckpt) = args.checkpoint.clone() else { + error!("--eval-only requires --checkpoint "); + std::process::exit(1); + }; + let eval_dir = match args.val_dir.clone().or_else(|| args.data_dir.clone()) { + Some(d) => d, + None => { + error!("--eval-only requires --val-dir (or --data-dir) pointing at an MM-Fi root"); + std::process::exit(1); + } + }; + let eval_ds = match MmFiDataset::discover( + &eval_dir, + config.window_frames, + config.num_subcarriers, + config.num_keypoints, + ) { + Ok(ds) => ds, + Err(e) => { + error!("Failed to load eval dataset: {e}"); + std::process::exit(1); + } + }; + if eval_ds.is_empty() { + error!("Eval dataset is empty — {}", eval_dir.display()); + std::process::exit(1); + } + info!( + "Eval dataset: {} windows 
from {}", + eval_ds.len(), + eval_dir.display() + ); + run_eval(config, &eval_ds, &ckpt, args.dump_preds.as_deref()); + return; + } + // ------------------------------------------------------------------ // Build datasets // ------------------------------------------------------------------ @@ -199,22 +261,48 @@ fn main() { info!("Dataset: {} samples", train_ds.len()); - // Use a small synthetic validation set when running without a split. - let val_syn_cfg = SyntheticConfig { - num_subcarriers: config.num_subcarriers, - num_antennas_tx: config.num_antennas_tx, - num_antennas_rx: config.num_antennas_rx, - window_frames: config.window_frames, - num_keypoints: config.num_keypoints, - signal_frequency_hz: 2.4e9, - }; - let val_ds = SyntheticCsiDataset::new(config.batch_size.max(1), val_syn_cfg); - info!( - "Using synthetic validation set ({} samples) for pipeline verification", - val_ds.len() - ); - - run_training(config, &train_ds, &val_ds); + // Validation: a held-out MM-Fi directory when provided (honest + // cross-subject metrics + meaningful checkpoint selection); + // otherwise a small synthetic set for pipeline verification only. 
+ if let Some(val_dir) = args.val_dir.as_ref() { + let val_ds = match MmFiDataset::discover( + val_dir, + config.window_frames, + config.num_subcarriers, + config.num_keypoints, + ) { + Ok(ds) => ds, + Err(e) => { + error!("Failed to load validation dataset: {e}"); + std::process::exit(1); + } + }; + if val_ds.is_empty() { + error!("Validation dataset is empty — {}", val_dir.display()); + std::process::exit(1); + } + info!( + "Using REAL held-out validation set: {} samples from {}", + val_ds.len(), + val_dir.display() + ); + run_training(config, &train_ds, &val_ds); + } else { + let val_syn_cfg = SyntheticConfig { + num_subcarriers: config.num_subcarriers, + num_antennas_tx: config.num_antennas_tx, + num_antennas_rx: config.num_antennas_rx, + window_frames: config.window_frames, + num_keypoints: config.num_keypoints, + signal_frequency_hz: 2.4e9, + }; + let val_ds = SyntheticCsiDataset::new(config.batch_size.max(1), val_syn_cfg); + info!( + "Using synthetic validation set ({} samples) for pipeline verification", + val_ds.len() + ); + run_training(config, &train_ds, &val_ds); + } } } @@ -251,6 +339,57 @@ fn run_training(config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn } } +/// Eval-only entry: load a checkpoint into a fresh model and run the +/// existing validation loop ([`Trainer::evaluate`]) over `ds`. 
+#[cfg(feature = "tch-backend")]
+fn run_eval(
+    config: TrainingConfig,
+    ds: &dyn CsiDataset,
+    checkpoint: &std::path::Path,
+    dump: Option<&std::path::Path>,
+) {
+    use wifi_densepose_train::trainer::Trainer;
+
+    // Restore the weights into a fresh trainer; a bad checkpoint is fatal.
+    let mut evaluator = Trainer::new(config);
+    let epoch = evaluator.load_checkpoint(checkpoint).unwrap_or_else(|e| {
+        error!("Failed to load checkpoint: {e}");
+        std::process::exit(1);
+    });
+    info!("Loaded checkpoint {} (epoch {})", checkpoint.display(), epoch);
+
+    // With a dump path the per-sample variant is used, otherwise the plain one.
+    let outcome = match dump {
+        Some(dump_path) => {
+            info!("Dumping per-sample predictions to {}", dump_path.display());
+            evaluator.evaluate_with_dump(ds, dump_path)
+        }
+        None => evaluator.evaluate(ds),
+    };
+
+    match outcome {
+        Ok(m) => info!("EVAL RESULT: {}", m.summary()),
+        Err(e) => {
+            error!("Evaluation failed: {e}");
+            std::process::exit(1);
+        }
+    }
+}
+
+/// Stand-in for builds without the `tch-backend` feature: no evaluation is
+/// performed here, so this variant only logs the missing feature and exits
+/// with status 1. All arguments are ignored.
+#[cfg(not(feature = "tch-backend"))]
+fn run_eval(
+    _config: TrainingConfig,
+    _ds: &dyn CsiDataset,
+    _checkpoint: &std::path::Path,
+    _dump: Option<&std::path::Path>,
+) {
+    error!("--eval-only requires the `tch-backend` feature");
+    std::process::exit(1);
+}
+
 #[cfg(not(feature = "tch-backend"))]
 fn run_training(_config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn CsiDataset) {
     info!(
diff --git a/v2/crates/wifi-densepose-train/src/losses.rs b/v2/crates/wifi-densepose-train/src/losses.rs
index c5a2d29de6..2ddd33cc0f 100644
--- a/v2/crates/wifi-densepose-train/src/losses.rs
+++ b/v2/crates/wifi-densepose-train/src/losses.rs
@@ -118,7 +118,7 @@ impl WiFiDensePoseLoss {
         // Normalise by number of visible joints in the batch.
         let n_visible = visibility.sum(Kind::Float);
         // Guard against division by zero (entire batch may have no labels).
-        let safe_n = n_visible.clamp(1.0, f64::MAX);
+        let safe_n = n_visible.clamp_min(1.0);
         masked.sum(Kind::Float) / safe_n
     }
 
@@ -165,7 +165,7 @@ impl WiFiDensePoseLoss {
         let masked_target_uv = target_uv * &fg_mask_f;
 
         // Count foreground pixels × 48 channels to normalise.
- let n_fg = fg_mask_f.sum(Kind::Float).clamp(1.0, f64::MAX); + let n_fg = fg_mask_f.sum(Kind::Float).clamp_min(1.0); // Smooth-L1 with beta=1.0, reduction=Sum then divide by fg count. let uv_loss_sum = masked_pred_uv.smooth_l1_loss(&masked_target_uv, Reduction::Sum, 1.0); @@ -234,7 +234,7 @@ impl WiFiDensePoseLoss { // UV loss (foreground masked) let fg_mask = target_int.not_equal(0_i64); let fg_mask_f = fg_mask.unsqueeze(1).expand_as(pu).to_kind(Kind::Float); - let n_fg = fg_mask_f.sum(Kind::Float).clamp(1.0, f64::MAX); + let n_fg = fg_mask_f.sum(Kind::Float).clamp_min(1.0); let uv_loss = (pu * &fg_mask_f).smooth_l1_loss(&(tu * &fg_mask_f), Reduction::Sum, 1.0) / n_fg; @@ -692,8 +692,11 @@ mod tests { let cy = (kp_y * s).round() as usize; let peak = hm[[cy, cx]]; + // kp 0.5 on a 64-px map is the half-pixel point 31.5, so the max + // attainable at pixel 31 or 32 is exp(-(0.5^2+0.5^2)/(2*2^2)) ≈ 0.9394. + // The old 0.95 threshold was mathematically unreachable. assert!( - peak > 0.95, + peak > 0.93, "Peak value {peak} should be close to 1.0 at centre" ); @@ -743,6 +746,7 @@ mod tests { } // Visible batch (index 1) should have non-zero heatmaps. + let heatmaps = &heatmaps; let batch1_sum: f32 = (0..num_joints) .map(|j| { (0..size) @@ -930,7 +934,10 @@ mod tests { let vis = Tensor::ones(&[1i64, 1], (Kind::Float, dev)); let hm = generate_gaussian_heatmaps(&kpts, &vis, 8, 1.5); let max_val: f64 = hm.max().double_value(&[]); - assert!(max_val > 0.9, "Peak value {max_val} should be > 0.9"); + // kp 0.5 on an 8-px map is the half-pixel point 3.5, so the max + // attainable at pixel 3 or 4 is exp(-0.5/(2*1.5^2)) ≈ 0.8007. + // The old 0.9 threshold was mathematically unreachable. 
+ assert!(max_val > 0.79, "Peak value {max_val} should be > 0.79"); } #[test] diff --git a/v2/crates/wifi-densepose-train/src/metrics.rs b/v2/crates/wifi-densepose-train/src/metrics.rs index 913afa054b..90ece4e278 100644 --- a/v2/crates/wifi-densepose-train/src/metrics.rs +++ b/v2/crates/wifi-densepose-train/src/metrics.rs @@ -19,6 +19,7 @@ use ndarray::{Array1, Array2, ArrayView1, ArrayView2}; use petgraph::graph::{DiGraph, NodeIndex}; +use petgraph::visit::EdgeRef; use ruvector_mincut::{DynamicMinCut, MinCutBuilder}; use std::collections::VecDeque; diff --git a/v2/crates/wifi-densepose-train/src/model.rs b/v2/crates/wifi-densepose-train/src/model.rs index 484eb47884..a473eac12a 100644 --- a/v2/crates/wifi-densepose-train/src/model.rs +++ b/v2/crates/wifi-densepose-train/src/model.rs @@ -182,7 +182,7 @@ impl WiFiDensePoseModel { self.vs .trainable_variables() .iter() - .map(|t| t.numel()) + .map(|t| t.numel() as i64) .sum() } @@ -297,7 +297,13 @@ fn apply_antenna_attention(x: &Tensor, lambda: f32) -> Tensor { let xi = x.select(0, bi as i64); // [n_ant, n_sc] // Move to CPU and convert to f32 for the pure-Rust attention kernel. - let flat: Vec = Vec::from(xi.to_kind(Kind::Float).to_device(Device::Cpu).contiguous()); + let flat: Vec = Vec::::try_from( + xi.to_kind(Kind::Float) + .to_device(Device::Cpu) + .contiguous() + .flatten(0, -1), + ) + .expect("tensor to Vec"); // Q = K = V = the antenna features (self-attention over antenna paths). let out = attn_mincut( @@ -350,7 +356,13 @@ fn apply_spatial_attention(x: &Tensor) -> Tensor { for bi in 0..b { // Extract [C, H*W] and transpose to [H*W, C]. let xi = x.select(0, bi).reshape([c, h * w]).transpose(0, 1); // [H*W, C] - let flat: Vec = Vec::from(xi.to_kind(Kind::Float).to_device(Device::Cpu).contiguous()); + let flat: Vec = Vec::::try_from( + xi.to_kind(Kind::Float) + .to_device(Device::Cpu) + .contiguous() + .flatten(0, -1), + ) + .expect("tensor to Vec"); // Build token slices — one per spatial position. 
let tokens: Vec<&[f32]> = (0..n_spatial).map(|i| &flat[i * d..(i + 1) * d]).collect(); diff --git a/v2/crates/wifi-densepose-train/src/proof.rs b/v2/crates/wifi-densepose-train/src/proof.rs index 35f9ff14e9..ce6c23e827 100644 --- a/v2/crates/wifi-densepose-train/src/proof.rs +++ b/v2/crates/wifi-densepose-train/src/proof.rs @@ -153,11 +153,11 @@ pub fn run_proof(proof_dir: &Path) -> Result = Vec::::from(kp.to_kind(Kind::Double).flatten(0, -1)) + let kp_vec: Vec = Vec::::try_from(kp.to_kind(Kind::Double).flatten(0, -1))? .iter() .map(|&x| x as f32) .collect(); - let vis_vec: Vec = Vec::::from(vis.to_kind(Kind::Double).flatten(0, -1)) + let vis_vec: Vec = Vec::::try_from(vis.to_kind(Kind::Double).flatten(0, -1))? .iter() .map(|&x| x as f32) .collect(); @@ -261,7 +261,7 @@ pub fn hash_model_weights(model: &WiFiDensePoseModel) -> String { .flatten(0, -1) .to_kind(Kind::Float) .to_device(Device::Cpu); - let values: Vec = Vec::::from(&flat); + let values: Vec = Vec::::try_from(&flat).expect("tensor to Vec"); let mut buf = vec![0u8; values.len() * 4]; for (i, v) in values.iter().enumerate() { let bytes = v.to_le_bytes(); diff --git a/v2/crates/wifi-densepose-train/src/trainer.rs b/v2/crates/wifi-densepose-train/src/trainer.rs index cbb7da72f5..66712ceaf0 100644 --- a/v2/crates/wifi-densepose-train/src/trainer.rs +++ b/v2/crates/wifi-densepose-train/src/trainer.rs @@ -394,6 +394,73 @@ impl Trainer { acc.finalize().ok_or(TrainError::EmptyDataset) } + /// Evaluate like [`Trainer::evaluate`], additionally dumping per-sample + /// predicted + ground-truth keypoints as JSONL to `dump_path`. + /// + /// Each output line is + /// `{"idx":N,"pred":[[x,y],...17],"gt":[[x,y],...17],"vis":[...17]}` with + /// samples in dataset order (no shuffle), so `idx` equals the dataset + /// window index. Used by the `--eval-only --dump-preds` path of the + /// `train` binary for offline rendering / analysis. 
+    pub fn evaluate_with_dump(
+        &self,
+        dataset: &dyn CsiDataset,
+        dump_path: &Path,
+    ) -> Result {
+        if dataset.is_empty() {
+            return Err(TrainError::EmptyDataset);
+        }
+
+        let mut acc = MetricsAccumulator::default_threshold();
+        let mut dump = std::io::BufWriter::new(
+            std::fs::File::create(dump_path)
+                .map_err(|e| TrainError::training_step(format!("create dump file: {e}")))?,
+        );
+
+        let batches = make_batches(
+            dataset,
+            self.config.batch_size,
+            false, // no shuffle during evaluation
+            self.config.seed,
+            self.device,
+        );
+
+        let mut global_idx: usize = 0;
+        for (amp_batch, phase_batch, kp_batch, vis_batch) in &batches {
+            let output = self.model.forward_inference(amp_batch, phase_batch);
+            let pred_kps = heatmap_to_keypoints(&output.keypoints);
+
+            for b in 0..kp_batch.size()[0] as usize {
+                let pred_kp_np = extract_kp_ndarray(&pred_kps, b);
+                let gt_kp_np = extract_kp_ndarray(kp_batch, b);
+                let vis_np = extract_vis_ndarray(vis_batch, b);
+                acc.update(&pred_kp_np, &gt_kp_np, &vis_np);
+
+                let pred_v: Vec<[f32; 2]> = pred_kp_np
+                    .rows()
+                    .into_iter()
+                    .map(|r| [r[0], r[1]])
+                    .collect();
+                let gt_v: Vec<[f32; 2]> =
+                    gt_kp_np.rows().into_iter().map(|r| [r[0], r[1]]).collect();
+                let line = serde_json::json!({
+                    "idx": global_idx,
+                    "pred": pred_v,
+                    "gt": gt_v,
+                    "vis": vis_np.to_vec(),
+                });
+                writeln!(dump, "{line}")
+                    .map_err(|e| TrainError::training_step(format!("write dump line: {e}")))?;
+                global_idx += 1;
+            }
+        }
+
+        // Flush explicitly: an error during BufWriter's drop would be silently discarded.
+        dump.flush()
+            .map_err(|e| TrainError::training_step(format!("flush dump file: {e}")))?;
+        acc.finalize().ok_or(TrainError::EmptyDataset)
+    }
+
     /// Save a training checkpoint.
     pub fn save_checkpoint(
         &self,
@@ -582,11 +649,13 @@ fn kp_to_heatmap_tensor(
     let num_kp = kp_tensor.size()[1] as usize;
 
     // Convert to ndarray for generate_target_heatmaps.
- let kp_vec: Vec = Vec::::from(kp_tensor.to_kind(Kind::Double).flatten(0, -1)) + let kp_vec: Vec = Vec::::try_from(kp_tensor.to_kind(Kind::Double).flatten(0, -1)) + .expect("kp tensor to Vec") .iter() .map(|&x| x as f32) .collect(); - let vis_vec: Vec = Vec::::from(vis_tensor.to_kind(Kind::Double).flatten(0, -1)) + let vis_vec: Vec = Vec::::try_from(vis_tensor.to_kind(Kind::Double).flatten(0, -1)) + .expect("vis tensor to Vec") .iter() .map(|&x| x as f32) .collect(); @@ -621,9 +690,11 @@ fn heatmap_to_keypoints(heatmaps: &Tensor) -> Tensor { // Argmax per joint → [B, 17] let arg = flat.argmax(-1, false); - // Decompose linear index into (row, col). - let row = (&arg / w).to_kind(Kind::Float); // [B, 17] - let col = (&arg % w).to_kind(Kind::Float); // [B, 17] + // Decompose linear index into (row, col). NB: `/` on an integer tensor + // is TRUE division in torch >= 1.6 (36 / 8 = 4.5), which skewed every + // decoded y by up to one row — use explicit floor division. + let row = arg.floor_divide_scalar(w).to_kind(Kind::Float); // [B, 17] + let col = arg.fmod(w).to_kind(Kind::Float); // [B, 17] // Normalize to [0, 1] let x = col / (w - 1) as f64; @@ -639,7 +710,8 @@ fn heatmap_to_keypoints(heatmaps: &Tensor) -> Tensor { fn extract_kp_ndarray(kp_tensor: &Tensor, batch_idx: usize) -> Array2 { let num_kp = kp_tensor.size()[1] as usize; let row = kp_tensor.select(0, batch_idx as i64); - let data: Vec = Vec::::from(row.to_kind(Kind::Double).flatten(0, -1)) + let data: Vec = Vec::::try_from(row.to_kind(Kind::Double).flatten(0, -1)) + .expect("kp tensor to Vec") .iter() .map(|&v| v as f32) .collect(); @@ -652,7 +724,8 @@ fn extract_kp_ndarray(kp_tensor: &Tensor, batch_idx: usize) -> Array2 { fn extract_vis_ndarray(vis_tensor: &Tensor, batch_idx: usize) -> Array1 { let num_kp = vis_tensor.size()[1] as usize; let row = vis_tensor.select(0, batch_idx as i64); - let data: Vec = Vec::::from(row.to_kind(Kind::Double)) + let data: Vec = Vec::::try_from(row.to_kind(Kind::Double)) + 
.expect("vis tensor to Vec") .iter() .map(|&v| v as f32) .collect();