
fix(train): un-rot tch-backend vs workspace tch 0.24 + clamp/floor-div runtime bugs + real --val-dir validation and --eval-only mode#1014

Open
stuinfla wants to merge 1 commit into ruvnet:main from stuinfla:fix/trainer-tch024-and-real-validation


@stuinfla


Problem

wifi-densepose-train's optional tch-backend feature does not compile against the workspace's own tch = "0.24" pin (13 errors), and even after compiling, the first real training step panics. Because mod metrics, mod model, mod trainer, etc. are all #[cfg(feature = "tch-backend")] and CI builds --no-default-features, none of this is visible to the PR gate — the feature bit-rotted silently. Full write-up with error sites: #1010.

On top of that, bin/train.rs always pairs a real --data-dir with a synthetic validation set, so on real data val_pck is noise (~0.02–0.06) and best-checkpoint selection / early stopping are driven by that noise — silently.

Root cause → fix (file:line against current main)

1. tch 0.24 API drift (compile errors):

  • Vec::<f64>::from(tensor) → conversions moved to TryFrom in tch 0.24: model.rs:300,359, proof.rs:156,160,264, trainer.rs:585,590,643,657
  • (&tensor % i64) → scalar Rem impl dropped; use .fmod(): trainer.rs:627
  • t.numel() returns usize but is summed into an i64; cast with as i64: model.rs:185
  • EdgeReference::id()/.target()/.weight() now need use petgraph::visit::EdgeRef; in scope: metrics.rs (errors at 1103–1105)

2. Runtime panic: losses.rs:121,168,237 call .clamp(1.0, f64::MAX). PyTorch ≥ 2.x rejects converting f64::MAX into an f32 tensor ("cannot be converted to type float without overflow"), so the first loss call panics. .clamp_min(1.0) expresses the intended semantics exactly.
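For readers without a libtorch build handy, the overflow can be illustrated in plain Rust. This is a minimal std-only sketch, not the tch code path: `overflows_f32` and `clamp_min` are hypothetical helper names used only to show why an upper bound of f64::MAX cannot be represented in f32, and why a lower-bound-only clamp gives the same result for finite inputs.

```rust
/// True when `v` is finite as an f64 but too large in magnitude to be a finite f32.
/// (Hypothetical helper, for illustration only.)
fn overflows_f32(v: f64) -> bool {
    v.is_finite() && v.abs() > f32::MAX as f64
}

/// Lower-bound-only clamp: the semantics that `.clamp_min(1.0)` expresses.
/// (Hypothetical scalar stand-in for the tensor op.)
fn clamp_min(x: f32, lo: f32) -> f32 {
    x.max(lo)
}

/// Two-sided clamp with the widest finite f32 upper bound, for comparison.
fn clamp_both(x: f32, lo: f32) -> f32 {
    x.clamp(lo, f32::MAX)
}
```

Because no finite f32 exceeds f32::MAX, the upper bound never takes effect, so dropping it changes nothing except removing the unrepresentable constant.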

3. dyld failure with LIBTORCH_USE_PYTORCH=1 (macOS/Linux): torch-sys only rpaths its own cc-built static lib, which does not propagate to downstream binaries — train dies at launch with libtorch_cpu not found, and so does the crate's own unit-test binary under cargo test --features tch-backend. New build.rs emits cargo:rustc-link-arg=-Wl,-rpath,<torch>/lib when that env var is set (plain rustc-link-arg, because the -bins/-tests variants do not reach the lib unit-test binary — verified with otool -l); a no-op otherwise.
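The shape of the rpath directive can be sketched as follows. This is an assumption-laden illustration rather than the build.rs in this PR: the helper names are hypothetical, the opt-in check is assumed to be a comparison against "1", and how the torch root directory is discovered is deliberately left to the caller because the description above does not spell it out.

```rust
use std::path::Path;

/// Build the Cargo directive that adds `<torch_root>/lib` to the runtime
/// library search path of the artifacts this package links.
/// (Hypothetical helper name.)
fn rpath_directive(torch_root: &Path) -> String {
    format!(
        "cargo:rustc-link-arg=-Wl,-rpath,{}",
        torch_root.join("lib").display()
    )
}

/// Return the directive only when the opt-in value is "1" (assumed check);
/// otherwise return None so the build script stays a no-op.
fn directive_if_enabled(use_pytorch: Option<&str>, torch_root: &Path) -> Option<String> {
    match use_pytorch {
        Some("1") => Some(rpath_directive(torch_root)),
        _ => None,
    }
}

// In an actual build.rs, `main` would do roughly the following, with
// `torch_root` located by whatever means the build script uses:
//   let flag = std::env::var("LIBTORCH_USE_PYTORCH").ok();
//   if let Some(line) = directive_if_enabled(flag.as_deref(), &torch_root) {
//       println!("{line}");
//   }
```

The plain rustc-link-arg form is used here for the reason given above: the -bins and -tests variants do not reach the lib unit-test binary.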

4. Latent decode bug exposed by the revived tests: trainer.rs:694 (heatmap_to_keypoints) computes (&arg / w) on the integer argmax tensor, which is true division in torch ≥ 1.6 (36/8 = 4.5), so every decoded y row was skewed by up to one row. floor_divide_scalar(w) restores integer row decomposition (the unit test pins center peak (4,4) → 4/7 exactly).
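The row/column decomposition can be checked on scalars. The sketch below is std-only and hypothetical (`decode_argmax` and `true_div_row` are illustration names, not functions from trainer.rs); it assumes a row-major 8×8 heatmap and normalisation by (size − 1), which is what the 4/7 figure in the unit test implies.

```rust
/// Split a flat row-major argmax index into normalised (x, y) coordinates.
/// Assumes w > 1 and h > 1 so the (size - 1) normalisation is well defined.
fn decode_argmax(idx: i64, w: i64, h: i64) -> (f64, f64) {
    let row = idx.div_euclid(w); // floor division: integer row
    let col = idx.rem_euclid(w); // remainder: integer column
    (col as f64 / (w - 1) as f64, row as f64 / (h - 1) as f64)
}

/// What true (floating-point) division yields for the row instead.
fn true_div_row(idx: i64, w: i64) -> f64 {
    idx as f64 / w as f64
}
```

For the center peak at (4,4) the flat index is 4·8 + 4 = 36: floor division gives row 4 and y = 4/7, while true division gives 4.5 and would place y at 4.5/7.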

5. Test-code fixes (pre-existing, tch-gated so invisible to CI):

  • losses.rs:749 — E0507: move closure inside an FnMut closure moves the captured heatmaps array on every outer iteration; shadowed with let heatmaps = &heatmaps;
  • losses.rs gaussian-peak thresholds were mathematically unreachable: kp 0.5 sits on a half-pixel, so the best attainable peak is exp(−0.5/(2σ²)) — 0.9394 for σ=2 (test demanded > 0.95) and 0.8007 for σ=1.5 (test demanded > 0.9). Relaxed to 0.93 / 0.79 with the math in comments.
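The attainable-peak arithmetic can be reproduced with a one-line helper. This is a std-only sketch with a hypothetical function name; `d2` is the squared distance from the keypoint to the nearest pixel center, and which offset each test actually uses is not stated in this description.

```rust
/// Value of an unnormalised Gaussian exp(-d2 / (2 * sigma^2)) evaluated at
/// squared distance `d2` from its center. (Hypothetical helper.)
fn gaussian_peak(d2: f64, sigma: f64) -> f64 {
    (-d2 / (2.0 * sigma * sigma)).exp()
}
```

With d2 = 0.5 and σ = 2 this gives ≈ 0.9394, matching the first figure. The 0.8007 figure corresponds to d2 = 1.0 at σ = 1.5, whereas d2 = 0.5 at σ = 1.5 would give ≈ 0.8948. Either way the value stays below the old 0.9 threshold.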

6. No real-validation path: new --val-dir <DIR> loads a held-out MM-Fi directory via MmFiDataset::discover as the validation set (synthetic fallback unchanged when the flag is absent). New --eval-only --checkpoint <FILE> [--dump-preds <JSONL>] mode loads a checkpoint and runs the existing validation loop standalone (Trainer::evaluate_with_dump dumps per-sample pred+GT keypoints in dataset order for offline rendering/analysis).

Test evidence

  • cargo check -p wifi-densepose-cli -p wifi-densepose-signal -p wifi-densepose-train --no-default-features — clean (matches CI gate)
  • LIBTORCH_USE_PYTORCH=1 cargo check -p wifi-densepose-train --features tch-backend — clean against PyTorch 2.11.0 (the 3 remaining warnings — model.rs:31 unused nn::ModuleT, proof.rs:28 unused CsiDataset, trainer.rs:723 unused num_kp — are pre-existing on main and deliberately untouched)
  • cargo test -p wifi-densepose-train --no-default-features — 7/7 pass
  • LIBTORCH_USE_PYTORCH=1 cargo test -p wifi-densepose-train --features tch-backend --lib -- --test-threads=1 — 194/195 pass. Honest accounting of what still fails and why (all pre-existing, none regressions from this PR):
    • model::tests::save_and_load_roundtrip fails against PyTorch 2.11 libtorch: tch 0.24's .pt (jit _load_parameters) reload hits Expected GenericDict but got Object. Version interplay, not this code — real checkpoint reload verified via .safetensors (below). Likely passes against the libtorch generation tch 0.24 targets.
    • proof::tests::hash_model_weights_is_deterministic and generate_and_verify_hash_matches are racy under parallel test threads (other tests draw from torch's shared global RNG between the two manual_seed calls); both pass with --test-threads=1, so they are counted above as passing. A proper fix (RNG mutex) felt out of scope.
  • End-to-end on hardware: with these fixes, training on ~8.6k live-captured 256-bin HE20 CSI windows (see #1005, "ESP32-C6 delivers HE20 CSI (256 bins / 242 tones) when built with IDF v5.5.2; resolves WITNESS-LOG-110 §B1 open question; includes IDF 5.5 build fix for c6_sync_espnow.c") with a held-out subject as --val-dir reached 40% cross-subject PCK@0.2 (vs 11.6% reported in ADR-150), with checkpoint selection driven by the real metric instead of synthetic noise.
  • --eval-only round-trip verified: reloading best_epoch0016_pck0.3999 (as .safetensors — see the PyTorch-2.11 .pt note above) and re-running the validation loop reproduces exactly PCK@0.2 = 0.3999 / OKS = 0.1277 (n=3191) on the original val set — the eval harness is bit-faithful to training-time validation.
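For the two racy proof tests noted above, the RNG-mutex idea that was left out of scope could look roughly like this. It is a std-only sketch with hypothetical names (`rng_lock`, `with_rng_lock`), not code from this PR, and it only helps if every test that seeds or draws from the shared global RNG goes through the same lock.

```rust
use std::sync::{Mutex, MutexGuard};

/// Process-wide lock for tests that seed and then draw from a shared RNG.
/// (Hypothetical helper.)
fn rng_lock() -> MutexGuard<'static, ()> {
    static LOCK: Mutex<()> = Mutex::new(());
    // A panicking test poisons the mutex; the guarded value is `()`,
    // so it is safe to recover the guard and carry on.
    LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner())
}

/// Run `body` while holding the RNG lock, so the seed call and the draws
/// that depend on it are not interleaved with other lock-holding tests.
fn with_rng_lock<T>(body: impl FnOnce() -> T) -> T {
    let _guard = rng_lock();
    body()
}
```

A test would then wrap its seed-then-hash sequence in `with_rng_lock(|| { ... })`, which would make --test-threads=1 unnecessary for those cases under the stated assumption.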

Formatting note: src/ablation.rs and src/bin/aa_score_runner.rs currently fail cargo fmt --check on main; left untouched here to keep this PR surgical.

Closes #1010

🤖 Generated with claude-flow

…ixes, clamp_min + floor-div runtime bugs, rpath build.rs, real --val-dir validation and --eval-only mode (ruvnet#1010)

The optional tch-backend feature is not CI-gated (mod metrics/model/trainer/
losses/proof are cfg'd behind it), so it drifted: 13 compile errors against
the workspace's own tch = "0.24" pin, plus latent runtime bugs the revived
tests exposed. All verified by compiling, running the unit suite, and
training + evaluating end-to-end on real CSI:

- tch 0.24 API: Vec<f64>::from(tensor) -> TryFrom (model.rs, proof.rs,
  trainer.rs — 8 sites); (&tensor % i64) Rem dropped -> .fmod()
  (trainer.rs); t.numel() usize/i64 sum mismatch (model.rs:185);
  petgraph EdgeReference methods need use petgraph::visit::EdgeRef
  (metrics.rs).
- Runtime panic: losses.rs .clamp(1.0, f64::MAX) — PyTorch >= 2.x rejects
  f64::MAX -> f32 tensor conversion; first loss call panics.
  .clamp_min(1.0) is the intended semantic.
- Latent decode bug: trainer.rs heatmap_to_keypoints used (&arg / w) on the
  integer argmax tensor — TRUE division in torch >= 1.6 (36/8 = 4.5), so
  every decoded y row was skewed by up to one row. floor_divide_scalar
  restores integer row decomposition.
- macOS/Linux + LIBTORCH_USE_PYTORCH=1: binaries and the unit-test binary
  failed at dyld (libtorch_cpu not found) — torch-sys only rpaths its own
  static lib. New build.rs emits cargo:rustc-link-arg=-Wl,-rpath,<torch>/lib
  (plain variant: -bins/-tests do not reach the lib unit-test binary).
- Test fixes (pre-existing, tch-gated so invisible to CI): losses.rs:749
  E0507 move-out-of-FnMut-capture; two gaussian-peak thresholds that were
  mathematically unreachable at half-pixel keypoints (max attainable
  exp(-0.5/(2*sigma^2)) = 0.9394 / 0.8007 vs asserted 0.95 / 0.9).
- Real validation: --val-dir <MM-Fi dir> uses a held-out directory as the
  validation set instead of the always-synthetic one (which made val_pck
  noise and checkpoint selection blind on real data). New --eval-only
  --checkpoint <FILE> [--dump-preds <JSONL>] mode runs the validation loop
  standalone via Trainer::evaluate_with_dump.

Suite: 194/195 with --features tch-backend (the 1 failure is a tch-0.24 vs
PyTorch-2.11 .pt jit reload interplay, pre-existing; checkpoint reload
verified via .safetensors reproducing PCK@0.2=0.3999 exactly).

Closes ruvnet#1010

Co-Authored-By: claude-flow <ruv@ruv.net>