Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions v2/crates/wifi-densepose-train/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//! Build script: when binding libtorch from the system PyTorch install
//! (`LIBTORCH_USE_PYTORCH=1`), embed an rpath to the PyTorch `lib` directory so
//! the produced binaries and test executables can locate `libtorch_cpu` at
//! runtime. torch-sys only adds the rpath to its own cc-built static library,
//! which does not propagate to downstream binaries on macOS/Linux.
fn main() {
println!("cargo:rerun-if-env-changed=LIBTORCH_USE_PYTORCH");
if std::env::var_os("LIBTORCH_USE_PYTORCH").is_none() {
return;
}
let target_os = std::env::var("CARGO_CFG_TARGET_OS").unwrap_or_default();
if target_os != "macos" && target_os != "linux" {
return;
}
let output = std::process::Command::new("python3")
.args([
"-c",
"import os, torch; print(os.path.join(os.path.dirname(torch.__file__), 'lib'))",
])
.output();
if let Ok(output) = output {
if output.status.success() {
let lib_dir = String::from_utf8_lossy(&output.stdout).trim().to_string();
if !lib_dir.is_empty() {
// Plain rustc-link-arg (not -bins/-tests): it is the only
// variant that also reaches the lib unit-test binary.
println!("cargo:rustc-link-arg=-Wl,-rpath,{lib_dir}");
}
}
}
}
171 changes: 155 additions & 16 deletions v2/crates/wifi-densepose-train/src/bin/train.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ struct Args {
#[arg(long, value_name = "DIR")]
data_dir: Option<PathBuf>,

/// Optional MM-Fi directory used as the REAL validation set (e.g. a
/// held-out subject/session for honest cross-subject PCK). Without it,
/// a small synthetic set is used for pipeline verification only.
#[arg(long, value_name = "DIR")]
val_dir: Option<PathBuf>,

/// Override the checkpoint output directory from the config.
#[arg(long, value_name = "DIR")]
checkpoint_dir: Option<PathBuf>,
Expand All @@ -76,6 +82,21 @@ struct Args {
/// Log level: trace, debug, info, warn, error.
#[arg(long, default_value = "info")]
log_level: String,

/// Evaluation-only mode: load `--checkpoint` and run the validation loop
/// on `--val-dir` (no training). Prints PCK@0.2 / OKS and exits.
#[arg(long, default_value_t = false)]
eval_only: bool,

/// Path to a `.pt` checkpoint (tch VarStore format) to load in
/// `--eval-only` mode.
#[arg(long, value_name = "FILE")]
checkpoint: Option<PathBuf>,

/// In `--eval-only` mode, dump per-sample predicted + GT keypoints as
/// JSONL to this path (dataset order, one line per window).
#[arg(long, value_name = "FILE")]
dump_preds: Option<PathBuf>,
}

// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -137,6 +158,47 @@ fn main() {

log_config_summary(&config);

// ------------------------------------------------------------------
// Eval-only mode: load checkpoint, run validation loop, exit
// ------------------------------------------------------------------

if args.eval_only {
let Some(ckpt) = args.checkpoint.clone() else {
error!("--eval-only requires --checkpoint <FILE>");
std::process::exit(1);
};
let eval_dir = match args.val_dir.clone().or_else(|| args.data_dir.clone()) {
Some(d) => d,
None => {
error!("--eval-only requires --val-dir (or --data-dir) pointing at an MM-Fi root");
std::process::exit(1);
}
};
let eval_ds = match MmFiDataset::discover(
&eval_dir,
config.window_frames,
config.num_subcarriers,
config.num_keypoints,
) {
Ok(ds) => ds,
Err(e) => {
error!("Failed to load eval dataset: {e}");
std::process::exit(1);
}
};
if eval_ds.is_empty() {
error!("Eval dataset is empty — {}", eval_dir.display());
std::process::exit(1);
}
info!(
"Eval dataset: {} windows from {}",
eval_ds.len(),
eval_dir.display()
);
run_eval(config, &eval_ds, &ckpt, args.dump_preds.as_deref());
return;
}

// ------------------------------------------------------------------
// Build datasets
// ------------------------------------------------------------------
Expand Down Expand Up @@ -199,22 +261,48 @@ fn main() {

info!("Dataset: {} samples", train_ds.len());

// Use a small synthetic validation set when running without a split.
let val_syn_cfg = SyntheticConfig {
num_subcarriers: config.num_subcarriers,
num_antennas_tx: config.num_antennas_tx,
num_antennas_rx: config.num_antennas_rx,
window_frames: config.window_frames,
num_keypoints: config.num_keypoints,
signal_frequency_hz: 2.4e9,
};
let val_ds = SyntheticCsiDataset::new(config.batch_size.max(1), val_syn_cfg);
info!(
"Using synthetic validation set ({} samples) for pipeline verification",
val_ds.len()
);

run_training(config, &train_ds, &val_ds);
// Validation: a held-out MM-Fi directory when provided (honest
// cross-subject metrics + meaningful checkpoint selection);
// otherwise a small synthetic set for pipeline verification only.
if let Some(val_dir) = args.val_dir.as_ref() {
let val_ds = match MmFiDataset::discover(
val_dir,
config.window_frames,
config.num_subcarriers,
config.num_keypoints,
) {
Ok(ds) => ds,
Err(e) => {
error!("Failed to load validation dataset: {e}");
std::process::exit(1);
}
};
if val_ds.is_empty() {
error!("Validation dataset is empty — {}", val_dir.display());
std::process::exit(1);
}
info!(
"Using REAL held-out validation set: {} samples from {}",
val_ds.len(),
val_dir.display()
);
run_training(config, &train_ds, &val_ds);
} else {
let val_syn_cfg = SyntheticConfig {
num_subcarriers: config.num_subcarriers,
num_antennas_tx: config.num_antennas_tx,
num_antennas_rx: config.num_antennas_rx,
window_frames: config.window_frames,
num_keypoints: config.num_keypoints,
signal_frequency_hz: 2.4e9,
};
let val_ds = SyntheticCsiDataset::new(config.batch_size.max(1), val_syn_cfg);
info!(
"Using synthetic validation set ({} samples) for pipeline verification",
val_ds.len()
);
run_training(config, &train_ds, &val_ds);
}
}
}

Expand Down Expand Up @@ -251,6 +339,57 @@ fn run_training(config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn
}
}

/// Eval-only entry: load a checkpoint into a fresh model and run the
/// existing validation loop ([`Trainer::evaluate`]) over `ds`.
#[cfg(feature = "tch-backend")]
fn run_eval(
config: TrainingConfig,
ds: &dyn CsiDataset,
checkpoint: &std::path::Path,
dump: Option<&std::path::Path>,
) {
use wifi_densepose_train::trainer::Trainer;

let mut trainer = Trainer::new(config);
match trainer.load_checkpoint(checkpoint) {
Ok(epoch) => info!(
"Loaded checkpoint {} (epoch {})",
checkpoint.display(),
epoch
),
Err(e) => {
error!("Failed to load checkpoint: {e}");
std::process::exit(1);
}
}

let result = if let Some(dump_path) = dump {
info!("Dumping per-sample predictions to {}", dump_path.display());
trainer.evaluate_with_dump(ds, dump_path)
} else {
trainer.evaluate(ds)
};

match result {
Ok(m) => info!("EVAL RESULT: {}", m.summary()),
Err(e) => {
error!("Evaluation failed: {e}");
std::process::exit(1);
}
}
}

#[cfg(not(feature = "tch-backend"))]
fn run_eval(
_config: TrainingConfig,
_ds: &dyn CsiDataset,
_checkpoint: &std::path::Path,
_dump: Option<&std::path::Path>,
) {
error!("--eval-only requires the `tch-backend` feature");
std::process::exit(1);
}

#[cfg(not(feature = "tch-backend"))]
fn run_training(_config: TrainingConfig, train_ds: &dyn CsiDataset, val_ds: &dyn CsiDataset) {
info!(
Expand Down
17 changes: 12 additions & 5 deletions v2/crates/wifi-densepose-train/src/losses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ impl WiFiDensePoseLoss {
// Normalise by number of visible joints in the batch.
let n_visible = visibility.sum(Kind::Float);
// Guard against division by zero (entire batch may have no labels).
let safe_n = n_visible.clamp(1.0, f64::MAX);
let safe_n = n_visible.clamp_min(1.0);

masked.sum(Kind::Float) / safe_n
}
Expand Down Expand Up @@ -165,7 +165,7 @@ impl WiFiDensePoseLoss {
let masked_target_uv = target_uv * &fg_mask_f;

// Count foreground pixels × 48 channels to normalise.
let n_fg = fg_mask_f.sum(Kind::Float).clamp(1.0, f64::MAX);
let n_fg = fg_mask_f.sum(Kind::Float).clamp_min(1.0);

// Smooth-L1 with beta=1.0, reduction=Sum then divide by fg count.
let uv_loss_sum = masked_pred_uv.smooth_l1_loss(&masked_target_uv, Reduction::Sum, 1.0);
Expand Down Expand Up @@ -234,7 +234,7 @@ impl WiFiDensePoseLoss {
// UV loss (foreground masked)
let fg_mask = target_int.not_equal(0_i64);
let fg_mask_f = fg_mask.unsqueeze(1).expand_as(pu).to_kind(Kind::Float);
let n_fg = fg_mask_f.sum(Kind::Float).clamp(1.0, f64::MAX);
let n_fg = fg_mask_f.sum(Kind::Float).clamp_min(1.0);
let uv_loss =
(pu * &fg_mask_f).smooth_l1_loss(&(tu * &fg_mask_f), Reduction::Sum, 1.0)
/ n_fg;
Expand Down Expand Up @@ -692,8 +692,11 @@ mod tests {
let cy = (kp_y * s).round() as usize;

let peak = hm[[cy, cx]];
// kp 0.5 on a 64-px map is the half-pixel point 31.5, so the max
// attainable at pixel 31 or 32 is exp(-(0.5^2+0.5^2)/(2*2^2)) ≈ 0.9394.
// The old 0.95 threshold was mathematically unreachable.
assert!(
peak > 0.95,
peak > 0.93,
"Peak value {peak} should be close to 1.0 at centre"
);

Expand Down Expand Up @@ -743,6 +746,7 @@ mod tests {
}

// Visible batch (index 1) should have non-zero heatmaps.
let heatmaps = &heatmaps;
let batch1_sum: f32 = (0..num_joints)
.map(|j| {
(0..size)
Expand Down Expand Up @@ -930,7 +934,10 @@ mod tests {
let vis = Tensor::ones(&[1i64, 1], (Kind::Float, dev));
let hm = generate_gaussian_heatmaps(&kpts, &vis, 8, 1.5);
let max_val: f64 = hm.max().double_value(&[]);
assert!(max_val > 0.9, "Peak value {max_val} should be > 0.9");
// kp 0.5 on an 8-px map is the half-pixel point 3.5, so the max
// attainable at pixel 3 or 4 is exp(-0.5/(2*1.5^2)) ≈ 0.8007.
// The old 0.9 threshold was mathematically unreachable.
assert!(max_val > 0.79, "Peak value {max_val} should be > 0.79");
}

#[test]
Expand Down
1 change: 1 addition & 0 deletions v2/crates/wifi-densepose-train/src/metrics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use petgraph::graph::{DiGraph, NodeIndex};
use petgraph::visit::EdgeRef;
use ruvector_mincut::{DynamicMinCut, MinCutBuilder};
use std::collections::VecDeque;

Expand Down
18 changes: 15 additions & 3 deletions v2/crates/wifi-densepose-train/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ impl WiFiDensePoseModel {
self.vs
.trainable_variables()
.iter()
.map(|t| t.numel())
.map(|t| t.numel() as i64)
.sum()
}

Expand Down Expand Up @@ -297,7 +297,13 @@ fn apply_antenna_attention(x: &Tensor, lambda: f32) -> Tensor {
let xi = x.select(0, bi as i64); // [n_ant, n_sc]

// Move to CPU and convert to f32 for the pure-Rust attention kernel.
let flat: Vec<f32> = Vec::from(xi.to_kind(Kind::Float).to_device(Device::Cpu).contiguous());
let flat: Vec<f32> = Vec::<f32>::try_from(
xi.to_kind(Kind::Float)
.to_device(Device::Cpu)
.contiguous()
.flatten(0, -1),
)
.expect("tensor to Vec<f32>");

// Q = K = V = the antenna features (self-attention over antenna paths).
let out = attn_mincut(
Expand Down Expand Up @@ -350,7 +356,13 @@ fn apply_spatial_attention(x: &Tensor) -> Tensor {
for bi in 0..b {
// Extract [C, H*W] and transpose to [H*W, C].
let xi = x.select(0, bi).reshape([c, h * w]).transpose(0, 1); // [H*W, C]
let flat: Vec<f32> = Vec::from(xi.to_kind(Kind::Float).to_device(Device::Cpu).contiguous());
let flat: Vec<f32> = Vec::<f32>::try_from(
xi.to_kind(Kind::Float)
.to_device(Device::Cpu)
.contiguous()
.flatten(0, -1),
)
.expect("tensor to Vec<f32>");

// Build token slices — one per spatial position.
let tokens: Vec<&[f32]> = (0..n_spatial).map(|i| &flat[i * d..(i + 1) * d]).collect();
Expand Down
6 changes: 3 additions & 3 deletions v2/crates/wifi-densepose-train/src/proof.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ pub fn run_proof(proof_dir: &Path) -> Result<ProofResult, Box<dyn std::error::Er
let num_kp = kp.size()[1] as usize;
let hm_size = cfg.heatmap_size;

let kp_vec: Vec<f32> = Vec::<f64>::from(kp.to_kind(Kind::Double).flatten(0, -1))
let kp_vec: Vec<f32> = Vec::<f64>::try_from(kp.to_kind(Kind::Double).flatten(0, -1))?
.iter()
.map(|&x| x as f32)
.collect();
let vis_vec: Vec<f32> = Vec::<f64>::from(vis.to_kind(Kind::Double).flatten(0, -1))
let vis_vec: Vec<f32> = Vec::<f64>::try_from(vis.to_kind(Kind::Double).flatten(0, -1))?
.iter()
.map(|&x| x as f32)
.collect();
Expand Down Expand Up @@ -261,7 +261,7 @@ pub fn hash_model_weights(model: &WiFiDensePoseModel) -> String {
.flatten(0, -1)
.to_kind(Kind::Float)
.to_device(Device::Cpu);
let values: Vec<f32> = Vec::<f32>::from(&flat);
let values: Vec<f32> = Vec::<f32>::try_from(&flat).expect("tensor to Vec<f32>");
let mut buf = vec![0u8; values.len() * 4];
for (i, v) in values.iter().enumerate() {
let bytes = v.to_le_bytes();
Expand Down
Loading